From ed9d82d6e8cd58fb97421242e6f26a64e0e66d53 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 11 Oct 2024 18:41:37 +0200 Subject: [PATCH 001/178] feat[next]: GTIR `as_fieldop` fusion pass (#1670) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a pass that transforms expressions like ``` as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 ) ``` into ``` as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) ``` --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 204 ++++++++++++++++++ .../next/iterator/type_system/inference.py | 10 + .../transforms_tests/test_fuse_as_fieldop.py | 112 ++++++++++ 4 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 src/gt4py/next/iterator/transforms/fuse_as_fieldop.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b2662fa278..19e26f24b6 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -446,7 +446,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: +def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py new file mode 100644 index 0000000000..51bbd91d83 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -0,0 +1,204 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.type_system import ( + inference as type_inference, + type_specifications as it_ts, +) +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _merge_arguments( + args1: dict[str, itir.Expr], arg2: dict[str, itir.Expr] +) -> dict[str, itir.Expr]: + new_args = {**args1} + for stencil_param, stencil_arg in arg2.items(): + if stencil_param not in new_args: + new_args[stencil_param] = stencil_arg + else: + assert new_args[stencil_param] == stencil_arg + return new_args + + +def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. + + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + type_inference.copy_type(from_=expr, to=new_expr) + + return new_expr + + return expr + + +@dataclasses.dataclass +class FuseAsFieldOp(eve.NodeTranslator): + """ + Merge multiple `as_fieldop` calls into one. + + >>> from gt4py import next as gtx + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = gtx.Dimension("IDim") + >>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + >>> d = im.domain("cartesian_domain", {IDim: (0, 1)}) + >>> nested_as_fieldop = im.op_as_fieldop("plus", d)( + ... im.op_as_fieldop("multiplies", d)( + ... im.ref("inp1", field_type), im.ref("inp2", field_type) + ... ), + ... im.ref("inp3", field_type), + ... ) + >>> print(nested_as_fieldop) + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + ) + >>> print( + ... FuseAsFieldOp.apply( + ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... ) + ... ) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + """ # noqa: RUF002 # ignore ambiguous multiplication character + + uids: eve_utils.UIDGenerator + + def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + @classmethod + def apply( + cls, + node: itir.Program, + *, + offset_provider, + uids: Optional[eve_utils.UIDGenerator] = None, + allow_undeclared_symbols=False, + ): + node = type_inference.infer( + node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + ) + + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "as_fieldop"): + node = _canonicalize_as_fieldop(node) + + if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): + stencil: itir.Lambda = node.fun.args[0] + domain = node.fun.args[1] if len(node.fun.args) > 1 else None + + shifts = trace_shifts.trace_stencil(stencil) + + args: list[itir.Expr] = node.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.extract_dtype(arg.type) + # TODO(tehrengruber): make this configurable + should_inline = isinstance(arg, itir.Literal) or ( + isinstance(arg, itir.FunCall) + and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) + ) + if should_inline: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif isinstance(arg, itir.Literal): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + # simplify stencil directly to keep the tree small + new_stencil_body = inline_lambdas.InlineLambdas.apply( + new_stencil_body, opcount_preserving=True + ) + new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + type_inference.copy_type(from_=node, to=new_node) + + return new_node + return node diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index bc1095dfb8..fccaa56232 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -96,6 +96,16 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ +def copy_type(from_: itir.Node, to: itir.Node) -> None: + """ + Copy type from one node to another. + + This function mainly exists for readability reasons. + """ + assert isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) + + def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: """ Execute `callback` as soon as all `args` have a type. diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py new file mode 100644 index 0000000000..da2c16336e --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -0,0 +1,112 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.ref("inp3", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2", "inp3")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp3")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_trivial_literal(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) + expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_symref_used_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( + im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)( + im.ref("inp1", field_type), im.ref("inp2", field_type) + ), + im.ref("inp1", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp1")) + ), + d, + )("inp1", "inp2") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + im.lambda_("a")( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) + ), + d1, + )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == testee + + +def test_partial_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + # first argument read at multiple locations -> not inlined + # second argument only reat at a single location -> inlined + im.lambda_("a", "b")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("b"), + ) + ), + d1, + )( + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a", "inp1")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("inp1"), + ) + ), + d1, + )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == expected From b339b82615eaabc9e2a5eb93f5656553863a7c6b Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 14 Oct 2024 16:24:47 +0200 Subject: [PATCH 002/178] feat[next]: Add IR transform to remove unnecessary cast expressions (#1688) Add IR transformation that removes cast expressions where the argument is already in the target type. --- .../next/iterator/transforms/prune_casts.py | 45 +++++++++++++++++++ .../runners/dace_fieldview/gtir_sdfg.py | 3 ++ .../transforms_tests/test_prune_casts.py | 23 ++++++++++ 3 files changed, 71 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/prune_casts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py new file mode 100644 index 0000000000..0720394db5 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -0,0 +1,45 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.type_system import type_specifications as ts + + +class PruneCasts(PreserveLocationVisitor, NodeTranslator): + """ + Removes cast expressions where the argument is already in the target type. + + This transformation requires the IR to be fully type-annotated, + therefore it should be applied after type-inference. + """ + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + + if not cpm.is_call_to(node, "cast_"): + return node + + value, type_constructor = node.args + + assert ( + value.type + and isinstance(type_constructor, ir.SymRef) + and (type_constructor.id in ir.TYPEBUILTINS) + ) + dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) + + if value.type == dtype: + return value + + return node + + @classmethod + def apply(cls, node: ir.Node) -> ir.Node: + return cls().visit(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7d878dde99..09d5d6c0d0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -26,6 +26,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -656,7 +657,9 @@ def build_sdfg_from_gtir( Returns: An SDFG in the DaCe canonical form (simplified) """ + ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = ir_prune_casts.PruneCasts().visit(ir) ir = dace_gtir_utils.patch_gtir(ir) sdfg_genenerator = GTIRToSDFG(offset_provider) sdfg = sdfg_genenerator.visit(ir) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py new file mode 100644 index 0000000000..462eed8408 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -0,0 +1,23 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms.prune_casts import PruneCasts +from gt4py.next.iterator.type_system import inference as type_inference + + +def test_prune_casts_simple(): + x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) + testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + + expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) + actual = PruneCasts.apply(testee) + assert actual == expected From 9feb51db27bde798245d3f80f4075e622bd42173 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Oct 2024 17:01:26 +0200 Subject: [PATCH 003/178] [next]: Fix inline lambda pass opcount preserving option (#1687) In #1531 the `itir.Node` class got a `type` attribute, that until now contributed to the hash computation of all nodes. As such two `itir.SymRef` with the same `id`, but one with a type inferred and one without (i.e. `None`) got a different hash value. Consequently the `inline_lambda` pass did not recognize them as a reference to the same symbol and erroneously inlined the expression even with `opcount_preserving=True`. This PR fixes the hash computation, such that again `node1 == node2` implies `hash(node1) == hash(node2)`. --- src/gt4py/next/iterator/ir.py | 12 ++++++++---- .../transforms_tests/test_inline_lambdas.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b2a549501f..42da4c83a6 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -37,10 +37,14 @@ def __str__(self) -> str: return pformat(self) def __hash__(self) -> int: - return hash(type(self)) ^ hash( - tuple( - hash(tuple(v)) if isinstance(v, list) else hash(v) - for v in self.iter_children_values() + return hash( + ( + type(self), + *( + tuple(v) if isinstance(v, list) else v + for (k, v) in self.iter_children_items() + if k not in ["location", "type"] + ), ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index e45281734b..2e0a83d33b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -8,6 +8,7 @@ import pytest +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -39,6 +40,21 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + # ensure opcount preserving option works whether `itir.SymRef` has a type or not + "typed_ref", + im.let("a", im.call("opaque")())( + im.plus(im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None)) + ), + { + True: im.let("a", im.call("opaque")())( + im.plus( # stays as is + im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None) + ) + ), + False: im.plus(im.call("opaque")(), im.call("opaque")()), + }, + ), ] From 5ce0fb8b9234c0f514ac83cc060dee6b549684c1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 16 Oct 2024 17:01:57 +0200 Subject: [PATCH 004/178] feat[next]: gtir lowering of broadcasted scalars (#1677) --- src/gt4py/next/ffront/foast_to_gtir.py | 5 ++++- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 948a8481d7..9cb0ce05f5 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -373,7 +373,10 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + expr = self.visit(node.args[0], **kwargs) + if isinstance(node.args[0].type, ts.ScalarType): + return im.as_fieldop(im.ref("deref"))(expr) + return expr def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 3951c410dc..09f18246dc 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -916,3 +916,14 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.id == "foo" assert lowered.expr == im.ref("inp") + + +def test_scalar_broadcast(): + def foo(): + return broadcast(1, (UDim, TDim)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert lowered.expr == im.as_fieldop("deref")(1) From 3f7fceed483e8a34b17fdf4d9a2625ecb0896759 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 16 Oct 2024 17:03:51 +0200 Subject: [PATCH 005/178] feat[next]: Allow type inference without domain argument to `as_fieldop` (#1689) In case we don't have a domain argument to `as_fieldop` we can not infer the exact result type. In order to still allow some passes which don't need this information to run before the domain inference, we continue with a dummy domain. One example is the CollapseTuple pass which only needs information about the structure, e.g. how many tuple elements does this node have, but not the dimensions of a field. Note that it might appear as if using the TraceShift pass would allow us to deduce the return type of `as_fieldop` without a domain, but this is not the case, since we don't have information on the ordering of dimensions. In this example ``` as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field) ``` it is unclear if the result has dimension I, J or J, I. --- .../next/iterator/type_system/inference.py | 4 ++- .../type_system/type_specifications.py | 2 +- .../iterator/type_system/type_synthesizer.py | 27 ++++++++++++++++--- .../iterator_tests/test_type_inference.py | 13 +++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fccaa56232..a13c7fb816 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -504,9 +504,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) + assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + node.dtype, ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index cfe3987b8c..94a174dca4 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -20,7 +20,7 @@ class NamedRangeType(ts.TypeSpec): @dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): - dims: list[common.Dimension] + dims: list[common.Dimension] | Literal["unknown"] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 77cd39389a..c836de1391 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -271,17 +271,36 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( - stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider + stencil: TypeSynthesizer, + domain: Optional[it_ts.DomainType] = None, + *, + offset_provider: common.OffsetProvider, ) -> TypeSynthesizer: + # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result + # type. In order to still allow some passes which don't need this information to run before the + # domain inference, we continue with a dummy domain. One example is the CollapseTuple pass + # which only needs information about the structure, e.g. how many tuple elements does this node + # have, but not the dimensions of a field. + # Note that it might appear as if using the TraceShift pass would allow us to deduce the return + # type of `as_fieldop` without a domain, but this is not the case, since we don't have + # information on the ordering of dimensions. In this example + # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` + # it is unclear if the result has dimension I, J or J, I. + if domain is None: + domain = it_ts.DomainType(dims="unknown") + @TypeSynthesizer - def applied_as_fieldop(*fields) -> ts.FieldType: + def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) + if domain.dims != "unknown" + else ts.DeferredType(constraint=ts.FieldType), + stencil_return, ) return applied_as_fieldop @@ -329,7 +348,7 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 05cd6b6854..20a1d7e9b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -478,3 +478,16 @@ def test_if_stmt(): result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field + + +def test_as_fieldop_without_domain(): + testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))( + im.ref("inp", float_i_field) + ) + result = itir_type_inference.infer( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.DeferredType(constraint=ts.FieldType) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype + ) From 0a27c7a415a8cb7ec61e2a3fe2cdd4595a3481d7 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 17 Oct 2024 16:08:03 +0200 Subject: [PATCH 006/178] feat[next][dace]: GTIR-to-DaCe lowering of map-reduce (only full connectivity) (#1683) This PR adds support for lowering of `map_` and `make_const_list` builtin functions. However, the current implementation only supports neighbor tables with full connectivity (no skip values). The support for skip values will be added in next PR. To be noted: - This PR generalizes the handling of tasklets without arguments inside a map scope. The return type for `input_connections` is extended to contain a `TaskletConnection` variant, which is lowered to an empty edge from map entry node to the tasklet node. - The result of `make_const_list` is a scalar value to be broadcasted on a local field. However, in order to keep the lowering simple, this value is represented as a 1D 1-element array (`shape=(1,)`). --- .../ir_utils/common_pattern_matcher.py | 10 + .../next/iterator/transforms/fuse_maps.py | 19 +- .../runners/dace_common/utility.py | 9 +- .../gtir_builtin_translators.py | 127 +++++-- .../runners/dace_fieldview/gtir_dataflow.py | 329 +++++++++++++----- .../dace_fieldview/gtir_python_codegen.py | 11 + .../runners/dace_fieldview/gtir_sdfg.py | 25 +- .../runners/dace_fieldview/utility.py | 46 +-- .../dace_tests/test_gtir_to_sdfg.py | 111 +++--- 9 files changed, 453 insertions(+), 234 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 4aea7ef149..16a88b282a 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -22,6 +22,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_map(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `map(λ(...) → ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "map_" + ) + + def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]: """Match expressions of the form `reduce(λ(...) → ...)(...)`.""" return ( diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 430d794880..8d27178682 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import TypeGuard from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator @@ -16,14 +15,6 @@ from gt4py.next.iterator.transforms import inline_lambdas -def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]: - return ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and node.fun.fun == ir.SymRef(id="map_") - ) - - @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ @@ -58,10 +49,10 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node) - if _is_map(node) or cpm.is_applied_reduce(node): - if any(_is_map(arg) for arg in node.args): + if cpm.is_applied_map(node) or cpm.is_applied_reduce(node): + if any(cpm.is_applied_map(arg) for arg in node.args): first_param = ( - 0 if _is_map(node) else 1 + 0 if cpm.is_applied_map(node) else 1 ) # index of the first param of op that maps to args (0 for map, 1 for reduce) assert isinstance(node.fun, ir.FunCall) assert isinstance(node.fun.args[0], (ir.Lambda, ir.SymRef)) @@ -76,7 +67,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_params.append(outer_op.params[0]) for i in range(len(node.args)): - if _is_map(node.args[i]): + if cpm.is_applied_map(node.args[i]): map_call = node.args[i] assert isinstance(map_call, ir.FunCall) assert isinstance(map_call.fun, ir.FunCall) @@ -102,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_body ) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later) new_op = ir.Lambda(params=new_params, expr=new_body) - if _is_map(node): + if cpm.is_applied_map(node): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index dec34ecbac..d678fdab7f 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -37,12 +37,13 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: raise ValueError(f"Scalar type '{type_}' not supported.") -def as_scalar_type(typestr: str) -> ts.ScalarType: - """Obtain GT4Py scalar type from generic numpy string representation.""" +def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: + """Get GT4Py scalar representation of a DaCe type.""" + type_name = str(dtype.as_numpy_dtype()) try: - kind = getattr(ts.ScalarKind, typestr.upper()) + kind = getattr(ts.ScalarKind, type_name.upper()) except AttributeError as ex: - raise ValueError(f"Data type {typestr} not supported.") from ex + raise ValueError(f"Data type {type_name} not supported.") from ex return ts.ScalarType(kind) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index e91bd880c6..8fb1451efb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,12 +10,13 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Iterable, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias import dace import dace.subsets as sbs from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.type_system import type_specifications as itir_ts @@ -32,16 +33,29 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes - - @dataclasses.dataclass(frozen=True) class Field: data_node: dace.nodes.AccessNode data_type: ts.FieldType | ts.ScalarType +FieldopDomain: TypeAlias = list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] +] +""" +Domain of a field operator represented as a list of tuples with 3 elements: + - dimension definition + - symbolic expression for lower bound (inclusive) + - symbolic expression for upper bound (exclusive) +""" + + FieldopResult: TypeAlias = Field | tuple[Field | tuple, ...] +"""Result of a field operator, can be either a field or a tuple fields.""" + + +INDEX_DTYPE: Final[dace.typeclass] = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) +"""Data type used for field indexing.""" class PrimitiveTranslator(Protocol): @@ -81,11 +95,11 @@ def _parse_fieldop_arg( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], + domain: FieldopDomain, reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + """Helper method to visit an expression passed as argument to a field operator.""" + arg = sdfg_builder.visit( node, sdfg=sdfg, @@ -101,10 +115,7 @@ def _parse_fieldop_arg( return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0])) elif isinstance(arg.data_type, ts.FieldType): indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = { - dim: gtir_dataflow.SymbolExpr( - dace_gtir_utils.get_map_variable(dim), - IteratorIndexDType, - ) + dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) for dim, _, _ in domain } dims = arg.data_type.dims + ( @@ -120,12 +131,11 @@ def _parse_fieldop_arg( def _create_temporary_field( sdfg: dace.SDFG, state: dace.SDFGState, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], + domain: FieldopDomain, node_type: ts.FieldType, - output_desc: dace.data.Data, + dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> Field: + """Helper method to allocate a temporary field where to write the output of a field operator.""" domain_dims, _, domain_ubs = zip(*domain) field_dims = list(domain_dims) # It should be enough to allocate an array with shape (upper_bound - lower_bound) @@ -138,6 +148,7 @@ def _create_temporary_field( # eliminate most of transient arrays. field_shape = list(domain_ubs) + output_desc = dataflow_output.result.node.desc(sdfg) if isinstance(output_desc, dace.data.Array): assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) @@ -157,7 +168,31 @@ def _create_temporary_field( return Field(field_node, field_type) -def translate_as_field_op( +def extract_domain(node: gtir.Node) -> FieldopDomain: + """ + Visits the domain of a field operator and returns a list of dimensions and + the corresponding lower and upper bounds. The returned lower bound is inclusive, + the upper bound is exclusive: [lower_bound, upper_bound[ + """ + assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) + + domain = [] + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = ( + dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) + for arg in named_range.args[1:3] + ) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + return domain + + +def translate_as_fieldop( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, @@ -188,25 +223,55 @@ def translate_as_field_op( assert isinstance(domain_expr, gtir.FunCall) # parse the domain of the field operator - domain = dace_gtir_utils.get_domain(domain_expr) + domain = extract_domain(domain_expr) + # The reduction identity value is used in place of skip values when building + # a list of neighbor values in the unstructured domain. + # + # A reduction on neighbor values can be either expressed in local view (itir): + # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ + # ← as_fieldop( + # λ(it) → reduce(plus, 0)(neighbors(V2Eₒ, it)), u⟨ Vertexₕ: [0, nvertices) ⟩ + # )(edges); + # + # or in field view (gtir): + # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ + # ← as_fieldop(λ(it) → reduce(plus, 0)(·it), u⟨ Vertexₕ: [0, nvertices) ⟩)( + # as_fieldop(λ(it) → neighbors(V2Eₒ, it), u⟨ Vertexₕ: [0, nvertices) ⟩)(edges) + # ); + # + # In local view, the list of neighbors is (recursively) built while visiting + # the current expression. + # In field view, the list of neighbors is built as argument to the current + # expression. Therefore, the reduction identity value needs to be passed to + # the argument visitor (`reduce_identity_for_args = reduce_identity`). if cpm.is_applied_reduce(stencil_expr.expr): if reduce_identity is not None: - raise NotImplementedError("nested reductions not supported.") - - # the reduce identity value is used to fill the skip values in neighbors list - _, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr) + raise NotImplementedError("Nested reductions are not supported.") + _, _, reduce_identity_for_args = gtir_dataflow.get_reduce_params(stencil_expr.expr) + elif cpm.is_call_to(stencil_expr.expr, "neighbors"): + # When the visitor hits a neighbors expression, we stop carrying the reduce + # identity further (`reduce_identity_for_args = None`) because the reduce + # identity value is filled in place of skip values in the context of neighbors + # itself, not in the arguments context. + # Besides, setting `reduce_identity_for_args = None` enables a sanity check + # that the sequence 'reduce(V2E) -> neighbors(V2E) -> reduce(C2E) -> neighbors(C2E)' + # is accepted, while 'reduce(V2E) -> reduce(C2E) -> neighbors(V2E) -> neighbors(C2E)' + # is not. The latter sequence would raise the 'NotImplementedError' exception above. + reduce_identity_for_args = None + else: + reduce_identity_for_args = reduce_identity # visit the list of arguments to be passed to the lambda expression stencil_args = [ - _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity) + _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity_for_args) for arg in node.args ] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.expr.node.desc(sdfg) + output_desc = output.result.node.desc(sdfg) domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) if isinstance(node.type.dtype, itir_ts.ListType): @@ -220,11 +285,17 @@ def translate_as_field_op( output_subset = sbs.Range.from_indices(domain_index) # create map range corresponding to the field operator domain - map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} - me, mx = sdfg_builder.add_map("field_op", state, map_ranges) + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc) + result_field = _create_temporary_field(sdfg, state, domain, node.type, output) # here we setup the edges from the map entry node for edge in input_edges: @@ -439,7 +510,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_scalar_type(gtir.INTEGER_INDEX_BUILTIN) + assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( @@ -566,7 +637,7 @@ def translate_symbol_ref( if TYPE_CHECKING: # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ - translate_as_field_op, + translate_as_fieldop, translate_if, translate_literal, translate_make_tuple, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 9739d7927a..0e571fc17d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -18,7 +18,7 @@ from gt4py import eve from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -145,7 +145,7 @@ class DataflowOutputEdge: """ state: dace.SDFGState - expr: DataExpr + result: DataExpr def connect( self, @@ -154,13 +154,13 @@ def connect( subset: sbs.Range, ) -> None: # retrieve the node which writes the result - last_node = self.state.in_edges(self.expr.node)[0].src + last_node = self.state.in_edges(self.result.node)[0].src if isinstance(last_node, dace.nodes.Tasklet): # the last transient node can be deleted - last_node_connector = self.state.in_edges(self.expr.node)[0].src_conn - self.state.remove_node(self.expr.node) + last_node_connector = self.state.in_edges(self.result.node)[0].src_conn + self.state.remove_node(self.result.node) else: - last_node = self.expr.node + last_node = self.result.node last_node_connector = None self.state.add_memlet_path( @@ -272,7 +272,12 @@ def _add_map( ], **kwargs: Any, ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: - """Helper method to add a map with unique name in current state.""" + """ + Helper method to add a map in current state. + + The subgraph builder ensures that the map receives a unique name, + by adding a unique suffix to the provided name. + """ return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) def _add_tasklet( @@ -283,7 +288,12 @@ def _add_tasklet( code: str, **kwargs: Any, ) -> dace.nodes.Tasklet: - """Helper method to add a tasklet with unique name in current state.""" + """ + Helper method to add a tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. + """ tasklet_node = self.subgraph_builder.add_tasklet( name, self.state, inputs, outputs, code, **kwargs ) @@ -295,15 +305,68 @@ def _add_tasklet( self.input_edges.append(edge) return tasklet_node + def _add_mapped_tasklet( + self, + name: str, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """ + Helper method to add a mapped tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. + """ + return self.subgraph_builder.add_mapped_tasklet( + name, self.state, map_ranges, inputs, code, outputs, **kwargs + ) + + def _construct_local_view(self, field: MemletExpr | DataExpr) -> DataExpr: + if isinstance(field, MemletExpr): + desc = field.node.desc(self.sdfg) + local_dim_indices = [i for i, size in enumerate(field.subset.size()) if size != 1] + if len(local_dim_indices) == 0: + # we are accessing a single-element array with shape (1,) + view_shape = (1,) + view_strides = (1,) + else: + view_shape = tuple(desc.shape[i] for i in local_dim_indices) + view_strides = tuple(desc.strides[i] for i in local_dim_indices) + view, _ = self.sdfg.add_view( + f"{field.node.data}_view", + view_shape, + desc.dtype, + strides=view_strides, + find_new_name=True, + ) + local_view_node = self.state.add_access(view) + self._add_input_data_edge(field.node, field.subset, local_view_node) + + return DataExpr(local_view_node, desc.dtype) + + else: + return field + def _construct_tasklet_result( self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str, + use_array: bool = False, ) -> DataExpr: temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) - data_type = dace_utils.as_scalar_type(str(dtype.as_numpy_dtype())) + if use_array: + # In some cases, such as result data with list-type annotation, we want + # that output data is represented as an array (single-element 1D array) + # in order to allow for composition of array shape in external memlets. + self.sdfg.add_array(temp_name, (1,), dtype, transient=True) + else: + self.sdfg.add_scalar(temp_name, dtype, transient=True) + data_type = dace_utils.as_itir_type(dtype) temp_node = self.state.add_access(temp_name) self._add_edge( src_node, @@ -412,6 +475,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: assert len(node.args) == 2 + assert isinstance(node.type, itir_ts.ListType) assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value @@ -422,9 +486,6 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) assert offset_provider.neighbor_axis in it.dimensions - neighbor_dim_index = it.dimensions.index(offset_provider.neighbor_axis) - assert offset_provider.neighbor_axis not in it.indices - assert offset_provider.origin_axis not in it.dimensions assert offset_provider.origin_axis in it.indices origin_index = it.indices[offset_provider.origin_axis] assert isinstance(origin_index, SymbolExpr) @@ -446,38 +507,24 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: # node). For the specific case of `neighbors` we need to nest the neighbors map # inside the field map and the memlets will traverse the external map and write # to the view nodes. The simplify pass will remove the redundant access nodes. - field_slice_view, field_slice_desc = self.sdfg.add_view( - f"{offset_provider.neighbor_axis.value}_view", - (field_desc.shape[neighbor_dim_index],), - field_desc.dtype, - strides=(field_desc.strides[neighbor_dim_index],), - find_new_name=True, - ) - field_slice_node = self.state.add_access(field_slice_view) - field_subset = ",".join( - it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis - else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) - ) - self._add_input_data_edge( - it.field, - sbs.Range.from_string(field_subset), - field_slice_node, - ) - - connectivity_slice_view, _ = self.sdfg.add_view( - "neighbors_view", - (offset_provider.max_neighbors,), - connectivity_desc.dtype, - strides=(connectivity_desc.strides[1],), - find_new_name=True, + field_slice = self._construct_local_view( + MemletExpr( + it.field, + sbs.Range.from_string( + ",".join( + it.indices[dim].value # type: ignore[union-attr] + if dim != offset_provider.neighbor_axis + else f"0:{size}" + for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + ) + ), + ) ) - connectivity_slice_node = self.state.add_access(connectivity_slice_view) - self._add_input_data_edge( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), - connectivity_slice_node, + connectivity_slice = self._construct_local_view( + MemletExpr( + self.state.add_access(connectivity), + sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + ) ) neighbors_temp, _ = self.sdfg.add_temp_transient( @@ -487,64 +534,135 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: offset_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) neighbor_idx = dace_gtir_utils.get_map_variable(offset_dim) - me, mx = self._add_map( - f"{offset}_neighbors", - { - neighbor_idx: f"0:{offset_provider.max_neighbors}", - }, - ) + index_connector = "__index" + output_connector = "__val" + tasklet_expression = f"{output_connector} = __field[{index_connector}]" + input_memlets = { + "__field": self.sdfg.make_array_memlet(field_slice.node.data), + index_connector: dace.Memlet(data=connectivity_slice.node.data, subset=neighbor_idx), + } + input_nodes = { + field_slice.node.data: field_slice.node, + connectivity_slice.node.data: connectivity_slice.node, + } + if offset_provider.has_skip_values: assert self.reduce_identity is not None assert self.reduce_identity.dtype == field_desc.dtype - # TODO: Investigate if a NestedSDFG brings benefits - tasklet_node = self._add_tasklet( - "gather_neighbors_with_skip_values", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}] if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {self.reduce_identity.dtype}({self.reduce_identity.value})", - ) + tasklet_expression += f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {field_desc.dtype}({self.reduce_identity.value})" + + self._add_mapped_tasklet( + name=f"{offset}_neighbors", + map_ranges={neighbor_idx: f"0:{offset_provider.max_neighbors}"}, + code=tasklet_expression, + inputs=input_memlets, + input_nodes=input_nodes, + outputs={ + output_connector: dace.Memlet(data=neighbors_temp, subset=neighbor_idx), + }, + output_nodes={neighbors_temp: neighbors_node}, + external_edges=True, + ) - else: - tasklet_node = self._add_tasklet( - "gather_neighbors", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}]", - ) + return DataExpr(neighbors_node, node.type) - self.state.add_memlet_path( - field_slice_node, - me, - tasklet_node, - dst_conn="__field", - memlet=dace.Memlet.from_array(field_slice_view, field_slice_desc), - ) - self.state.add_memlet_path( - connectivity_slice_node, - me, - tasklet_node, - dst_conn=index_connector, - memlet=dace.Memlet(data=connectivity_slice_view, subset=neighbor_idx), - ) - self.state.add_memlet_path( - tasklet_node, - mx, - neighbors_node, - src_conn="__val", - memlet=dace.Memlet(data=neighbors_temp, subset=neighbor_idx), - ) + def _visit_map(self, node: gtir.FunCall) -> DataExpr: + """ + A map node defines an operation to be mapped on all elements of input arguments. + + The map operation is applied on the local dimension of input fields. + In the example below, the local dimension consists of a list of neighbor + values as the first argument, and a list of constant values `1.0`: + `map_(plus)(neighbors(V2E, it), make_const_list(1.0))` + + The `plus` operation is lowered to a tasklet inside a map that computes + the domain of the local dimension (in this example, max neighbors in V2E). + The result is a 1D local field, with same size as the input local dimension. + In above example, the result would be an array with size V2E.max_neighbors, + containing the V2E neighbor values incremented by 1.0. + """ assert isinstance(node.type, itir_ts.ListType) - return DataExpr(neighbors_node, node.type) + assert isinstance(node.fun, gtir.FunCall) + assert len(node.fun.args) == 1 # the operation to be mapped on the arguments + + assert isinstance(node.type.element_type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type.element_type) + + input_args = [self.visit(arg) for arg in node.args] + input_connectors = [f"__arg{i}" for i in range(len(input_args))] + output_connector = "__out" + + # Here we build the body of the tasklet + fun_node = im.call(node.fun.args[0])(*input_connectors) + fun_python_code = gtir_python_codegen.get_source(fun_node) + tasklet_expression = f"{output_connector} = {fun_python_code}" + + # TODO(edopao): extract offset_dim from the input arguments + offset_dim = gtx_common.Dimension("", gtx_common.DimensionKind.LOCAL) + map_index = dace_gtir_utils.get_map_variable(offset_dim) + + # The dataflow we build in this class has some loose connections on input edges. + # These edges are described as set of nodes, that will have to be connected to + # external data source nodes passing through the map entry node of the field map. + # Similarly to `neighbors` expressions, the `map_` input edges terminate on view + # nodes (see `_construct_local_view` in the for-loop below), because it is simpler + # than representing map-to-map edges (which require memlets with 2 pass-nodes). + input_memlets = {} + input_nodes = {} + local_size: Optional[int] = None + for conn, input_expr in zip(input_connectors, input_args): + input_node = self._construct_local_view(input_expr).node + input_desc = input_node.desc(self.sdfg) + # we assume that there is a single local dimension + if len(input_desc.shape) != 1: + raise ValueError(f"More than one local dimension in map expression {node}.") + input_size = input_desc.shape[0] + if input_size == 1: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0") + elif local_size is not None and input_size != local_size: + raise ValueError(f"Invalid node {node}") + else: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) + local_size = input_size + + input_nodes[input_node.data] = input_node + + if local_size is None: + # corner case where map is applied to 1-element lists + assert len(input_nodes) >= 1 + local_size = 1 + + out, _ = self.sdfg.add_temp_transient((local_size,), dtype) + out_node = self.state.add_access(out) + + self._add_mapped_tasklet( + name="map", + map_ranges={map_index: f"0:{local_size}"}, + code=tasklet_expression, + inputs=input_memlets, + input_nodes=input_nodes, + outputs={ + output_connector: dace.Memlet(data=out, subset=map_index), + }, + output_nodes={out: out_node}, + external_edges=True, + ) + + return DataExpr(out_node, dtype) def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: + assert isinstance(node.type, ts.ScalarType) op_name, reduce_init, reduce_identity = get_reduce_params(node) - dtype = reduce_identity.dtype - # We store the value of reduce identity in the visitor context while visiting - # the input to reduction; this value will be use by the `neighbors` visitor - # to fill the skip values in the neighbors list. + # The input to reduction is a list of elements on a local dimension. + # This list is provided by an argument that typically calls the neighbors + # builtin function, to built a list of neighbor values for each element + # in the field target dimension. + # We store the value of reduce identity in the visitor context to have it + # available while visiting the input to reduction; this value might be used + # by the `neighbors` visitor to fill the skip values in the neighbors list. prev_reduce_identity = self.reduce_identity self.reduce_identity = reduce_identity @@ -585,7 +703,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: ) temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) + self.sdfg.add_scalar(temp_name, reduce_identity.dtype, transient=True) temp_node = self.state.add_access(temp_name) self.state.add_nedge( @@ -593,7 +711,6 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: temp_node, dace.Memlet(data=temp_name, subset="0"), ) - assert isinstance(node.type, ts.ScalarType) return DataExpr(temp_node, node.type) def _split_shift_args( @@ -816,9 +933,6 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: Generic handler called by `visit_FunCall()` when it encounters a builtin function that does not match any other specific handler. """ - assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) - node_internals = [] node_connections: dict[str, MemletExpr | DataExpr] = {} for i, arg in enumerate(node.args): @@ -863,7 +977,27 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: connector, ) - return self._construct_tasklet_result(dtype, tasklet_node, "result") + if isinstance(node.type, itir_ts.ListType): + # The only builtin function (so far) handled here that returns a list + # is 'make_const_list'. There are other builtin functions (map_, neighbors) + # that return a list but they are handled in specialized visit methods. + # This method (the generic visitor for builtin functions) always returns + # a single value. This is also the case of 'make_const_list' expression: + # it simply broadcasts a scalar on the local domain of another expression, + # for example 'map_(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. + # Therefore we handle `ListType` as a single-element array with shape (1,) + # that will be accessed in a map expression on a local domain. + assert isinstance(node.type.element_type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type.element_type) + # In order to ease the lowring of the parent expression on local dimension, + # we represent the scalar value as a single-element 1D array. + use_array = True + else: + assert isinstance(node.type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type) + use_array = False + + return self._construct_tasklet_result(dtype, tasklet_node, "result", use_array=use_array) def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: if cpm.is_call_to(node, "deref"): @@ -872,6 +1006,9 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_applied_map(node): + return self._visit_map(node) + elif cpm.is_applied_reduce(node): return self._visit_reduce(node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index f133a9224d..6aee33c56e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -75,6 +75,7 @@ def builtin_cast(*args: Any) -> str: val, target_type = args + assert target_type in gtir.TYPEBUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) @@ -83,9 +84,19 @@ def builtin_if(*args: Any) -> str: return f"{true_val} if {cond} else {false_val}" +def make_const_list(arg: str) -> str: + """ + Takes a single scalar argument and broadcasts this value on the local dimension + of map expression. In a dataflow, we represent it as a tasklet that writes + a value to a scalar node. + """ + return arg + + GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { "cast_": builtin_cast, "if_": builtin_if, + "make_const_list": make_const_list, } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 09d5d6c0d0..d79d887318 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -77,6 +77,21 @@ def add_tasklet( unique_name = self.unique_tasklet_name(name) return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) + def add_mapped_tasklet( + self, + name: str, + state: dace.SDFGState, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" + unique_name = self.unique_tasklet_name(name) + return state.add_mapped_tasklet(unique_name, map_ranges, inputs, code, outputs, **kwargs) + class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" @@ -111,7 +126,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) - tesklet_uids: eve.utils.UIDGenerator = dataclasses.field( + tasklet_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) @@ -125,7 +140,7 @@ def unique_map_name(self, name: str) -> str: return f"{self.map_uids.sequential_id()}_{name}" def unique_tasklet_name(self, name: str) -> str: - return f"{self.tesklet_uids.sequential_id()}_{name}" + return f"{self.tasklet_uids.sequential_id()}_{name}" def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] @@ -353,7 +368,9 @@ def visit_SetAt( target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) # convert domain expression to dictionary to ease access to dimension boundaries - domain = dace_gtir_utils.get_domain_ranges(stmt.domain) + domain = { + dim: (lb, ub) for dim, lb, ub in gtir_builtin_translators.extract_domain(stmt.domain) + } expr_input_args = { sym_id @@ -422,7 +439,7 @@ def visit_FunCall( node, sdfg, head_state, self, reduce_identity ) elif cpm.is_applied_as_fieldop(node): - return gtir_builtin_translators.translate_as_field_op( + return gtir_builtin_translators.translate_as_fieldop( node, sdfg, head_state, self, reduce_identity ) elif isinstance(node.fun, gtir.Lambda): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 2988b01a61..855dc9c91a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -11,61 +11,19 @@ import itertools from typing import Any -import dace - from gt4py import eve from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts -def get_domain( - node: gtir.Expr, -) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Specialized visit method for domain expressions. - - Returns for each domain dimension the corresponding range. - - TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea - would be to cache the results of lowering here (e.g. using `functools.lru_cache`) - """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - - domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - bounds = [ - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ] - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, bounds[0], bounds[1])) - - return domain - - -def get_domain_ranges( - node: gtir.Expr, -) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Returns domain represented in dictionary form. - """ - domain = get_domain(node) - - return {dim: (lb, ub) for dim, lb, ub in domain} - - def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. """ suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" + # TODO(edopao): raise exception if dim.value is empty return f"i_{dim.value}_gtx_{dim.kind}{suffix}" @@ -140,7 +98,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> gtir.Node: ) node.args = [] - node.args = [self.visit(arg) for arg in node.args] + node.args = self.visit(node.args) node.fun = self.visit(node.fun) return node diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e819cdcd8c..98e15dac3c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1323,63 +1323,86 @@ def test_gtir_reduce_with_skip_values(): def test_gtir_reduce_dot_product(): - # FIXME[#1582](edopao): Enable testcase when type inference is working - pytest.skip("Field of lists not fully supported as a type in GTIR yet") init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - testee = gtir.Program( - id="reduce_dot_product", - function_definitions=[], - params=[ - gtir.Sym(id="edges", type=EFTYPE), - gtir.Sym(id="vertices", type=VFTYPE), - gtir.Sym(id="nvertices", type=SIZE_TYPE), - ], - declarations=[], - body=[ - gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) - )( - im.op_as_fieldop("multiplies", vertex_domain)( - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - ), - ), - domain=vertex_domain, - target=gtir.SymRef(id="vertices"), - ) - ], - ) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) - e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) v_ref = [ - reduce(lambda x, y: x + y, e[v2e_neighbors] * e[v2e_neighbors], init_value) + functools.reduce( + lambda x, y: x + y, (e[v2e_neighbors] * e[v2e_neighbors]) + 1.0, init_value + ) for v2e_neighbors in connectivity_V2E.table ] - sdfg( - e, - v, - connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), + stencil_inlined = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.map_("plus")( + im.map_("multiplies")( + im.neighbors("V2E", "it"), + im.neighbors("V2E", "it"), + ), + im.call("make_const_list")(1.0), + ) + ) + ), + vertex_domain, + ) + )("edges") + + stencil_fieldview = im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + ) + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + ), + im.op_as_fieldop("make_const_list", vertex_domain)(1.0), + ) ) - assert np.allclose(v, v_ref) + + for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): + testee = gtir.Program( + id=f"reduce_dot_product_{i}", + function_definitions=[], + params=[ + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=stencil, + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], + ) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + + sdfg( + e, + v, + connectivity_V2E=connectivity_V2E.table, + **FSYMBOLS, + **make_mesh_symbols(SIMPLE_MESH), + ) + assert np.allclose(v, v_ref) def test_gtir_reduce_with_cond_neighbors(): From 78791c77ece94ae80bb3ab49cb96bd2d72d7a8ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 18 Oct 2024 11:13:08 +0200 Subject: [PATCH 007/178] feat[next][dace]: GTIR-to-SDFG broadcast tuple scalar arg + cleanup (#1693) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements the lowering of expressions like: `as_fieldop(deref, domain)(x[0] + (x[1][0] × 2.0 + x[1][1] × 3.0))` where the elements of tuple `x` being accessed are all scalar values, so `x[0] + (x[1][0] × 2.0 + x[1][1] × 3.0)` is a scalar expression. Therefore, this PR allows to lower expressions where the argument to `as_field_op` is a scalar expression rather than a lambda. Note that: - in case of a lambda, the `as_fieldop` function computes the lambda over the field domain; - in case of scalar expression, it broadcasts the scalar value over the field domain. Additionally, some cleanup: - Removed call to `gt_simplify` so that dace backend returns a plain unoptimized SDFG. - Removed `patch_gtir`, since the functionality of this IR pass is provided by similar changes delivered in #1677 and #1689. --- .../gtir_builtin_translators.py | 83 +++++++++++++++++-- .../runners/dace_fieldview/gtir_sdfg.py | 4 - .../runners/dace_fieldview/utility.py | 41 --------- .../dace_tests/test_gtir_to_sdfg.py | 69 +++++++++++++-- 4 files changed, 139 insertions(+), 58 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 8fb1451efb..a8ae1cc0e8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -209,8 +209,8 @@ def translate_as_fieldop( The dataflow can be as simple as a single tasklet, or implement a local computation as a composition of tasklets and even include a map to range on local dimensions (e.g. neighbors and map builtins). - The stencil dataflow is instantiated inside a map scope, which apply the stencil over - the field domain. + The stencil dataflow is instantiated inside a map scope, which applies the stencil + over the field domain. """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") @@ -219,11 +219,24 @@ def translate_as_fieldop( fun_node = node.fun assert len(fun_node.args) == 2 stencil_expr, domain_expr = fun_node.args - assert isinstance(stencil_expr, gtir.Lambda) - assert isinstance(domain_expr, gtir.FunCall) + + if isinstance(stencil_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + pass + elif cpm.is_ref_to(stencil_expr, "deref"): + # Special usage of 'deref' as argument to fieldop expression, to pass a scalar + # value to 'as_fieldop' function. It results in broadcasting the scalar value + # over the field domain. + return translate_broadcast_scalar(node, sdfg, state, sdfg_builder, reduce_identity) + else: + raise NotImplementedError( + f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + ) # parse the domain of the field operator domain = extract_domain(domain_expr) + domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) # The reduction identity value is used in place of skip values when building # a list of neighbor values in the unstructured domain. @@ -273,16 +286,15 @@ def translate_as_fieldop( input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) output_desc = output.result.node.desc(sdfg) - domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) if isinstance(node.type.dtype, itir_ts.ListType): assert isinstance(output_desc, dace.data.Array) assert set(output_desc.offset) == {0} # additional local dimension for neighbors # TODO(phimuell): Investigate if we should swap the two. - output_subset = sbs.Range.from_indices(domain_index) + sbs.Range.from_array(output_desc) + output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) else: assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_index) + output_subset = sbs.Range.from_indices(domain_indices) # create map range corresponding to the field operator domain me, mx = sdfg_builder.add_map( @@ -307,6 +319,62 @@ def translate_as_fieldop( return result_field +def translate_broadcast_scalar( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], +) -> FieldopResult: + """ + Generates the dataflow subgraph for the 'as_fieldop' builtin function for the + special case where the argument to 'as_fieldop' is a 'deref' scalar expression, + rather than a lambda function. This case corresponds to broadcasting the scalar + value over the field domain. Therefore, it is lowered to a mapped tasklet that + just writes the scalar value out to all elements of the result field. + """ + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, ts.FieldType) + + fun_node = node.fun + assert len(fun_node.args) == 2 + stencil_expr, domain_expr = fun_node.args + assert cpm.is_ref_to(stencil_expr, "deref") + + domain = extract_domain(domain_expr) + domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + + assert len(node.args) == 1 + assert isinstance(node.args[0].type, ts.ScalarType) + scalar_expr = _parse_fieldop_arg( + node.args[0], sdfg, state, sdfg_builder, domain, reduce_identity=None + ) + assert isinstance(scalar_expr, gtir_dataflow.MemletExpr) + assert scalar_expr.subset == sbs.Indices.from_string("0") + result = gtir_dataflow.DataflowOutputEdge( + state, gtir_dataflow.DataExpr(scalar_expr.node, node.args[0].type) + ) + result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result) + + sdfg_builder.add_mapped_tasklet( + "broadcast", + state, + map_ranges={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + inputs={"__inp": dace.Memlet(data=scalar_expr.node.data, subset="0")}, + code="__val = __inp", + outputs={"__val": dace.Memlet(data=result_field.data_node.data, subset=domain_indices)}, + input_nodes={scalar_expr.node.data: scalar_expr.node}, + output_nodes={result_field.data_node.data: result_field.data_node}, + external_edges=True, + ) + + return result_field + + def translate_if( node: gtir.Node, sdfg: dace.SDFG, @@ -638,6 +706,7 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, + translate_broadcast_scalar, translate_if, translate_literal, translate_make_tuple, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index d79d887318..3697609d76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -32,7 +32,6 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_builtin_translators, gtir_dataflow, - transformations as gtx_transformations, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -665,7 +664,6 @@ def build_sdfg_from_gtir( The lowering to SDFG requires that the program node is type-annotated, therefore this function runs type ineference as first step. - As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form. Arguments: ir: The GTIR program node to be lowered to SDFG @@ -677,10 +675,8 @@ def build_sdfg_from_gtir( ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) ir = ir_prune_casts.PruneCasts().visit(ir) - ir = dace_gtir_utils.patch_gtir(ir) sdfg_genenerator = GTIRToSDFG(offset_provider) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) - gtx_transformations.gt_simplify(sdfg) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 855dc9c91a..b5c447a1be 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -11,10 +11,7 @@ import itertools from typing import Any -from gt4py import eve from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts @@ -65,41 +62,3 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: return ts.TupleType( types=[get_tuple_type(d) if isinstance(d, tuple) else d.data_type for d in data] ) - - -def patch_gtir(ir: gtir.Program) -> gtir.Program: - """ - Make the IR compliant with the requirements of lowering to SDFG. - - Applies canonicalization of as_fieldop expressions as well as some temporary workarounds. - This allows to lower the IR to SDFG for some special cases. - """ - - class PatchGTIR(eve.PreserveLocationVisitor, eve.NodeTranslator): - def visit_FunCall(self, node: gtir.FunCall) -> gtir.Node: - if cpm.is_applied_as_fieldop(node): - assert isinstance(node.fun, gtir.FunCall) - assert isinstance(node.type, ts.FieldType) - - # Handle the case of fieldop without domain. This case should never happen, but domain - # inference currently produces this kind of nodes for unreferenced tuple fields. - # TODO(tehrengruber): remove this workaround once domain ineference supports this case - if len(node.fun.args) == 1: - return gtir.Literal(value="0", type=node.type.dtype) - - assert len(node.fun.args) == 2 - stencil = node.fun.args[0] - - # Canonicalize as_fieldop: always expect a lambda expression. - # Here we replace the call to deref with a lambda expression and empty arguments list. - if cpm.is_ref_to(stencil, "deref"): - node.fun.args[0] = gtir.Lambda( - expr=gtir.FunCall(fun=stencil, args=node.args), params=[] - ) - node.args = [] - - node.args = self.visit(node.args) - node.fun = self.visit(node.fun) - return node - - return PatchGTIR().visit(ir) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 98e15dac3c..728b4b02b9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -40,10 +40,11 @@ N = 10 -IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -CFTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +FLOAT_TYPE = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +IFTYPE = ts.FieldType(dims=[IDim], dtype=FLOAT_TYPE) +CFTYPE = ts.FieldType(dims=[Cell], dtype=FLOAT_TYPE) +EFTYPE = ts.FieldType(dims=[Edge], dtype=FLOAT_TYPE) +VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { "IDim": IDim, @@ -315,6 +316,62 @@ def test_gtir_tuple_expr(): assert np.allclose(c, a * 2 + b) +def test_gtir_tuple_broadcast_scalar(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="gtir_tuple_broadcast_scalar", + function_definitions=[], + params=[ + gtir.Sym( + id="x", + type=ts.TupleType(types=[FLOAT_TYPE, ts.TupleType(types=[FLOAT_TYPE, FLOAT_TYPE])]), + ), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.as_fieldop("deref", domain)( + im.plus( + im.tuple_get(0, "x"), + im.plus( + im.multiplies_( + im.tuple_get( + 0, + im.tuple_get(1, "x"), + ), + 2.0, + ), + im.multiplies_( + im.tuple_get( + 1, + im.tuple_get(1, "x"), + ), + 3.0, + ), + ), + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand() + b = np.random.rand() + c = np.random.rand() + d = np.empty(N, dtype=type(a)) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + x_fields = (a, b, c) + + sdfg(*x_fields, d, **FSYMBOLS) + assert np.allclose(d, a + 2 * b + 3 * c) + + def test_gtir_tuple_return(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) testee = gtir.Program( @@ -954,8 +1011,8 @@ def test_gtir_connectivity_shift(): im.op_as_fieldop("plus", edge_domain)("e2v_offset", 0), ) - CE_FTYPE = ts.FieldType(dims=[Cell, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - EV_FTYPE = ts.FieldType(dims=[Edge, Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + CE_FTYPE = ts.FieldType(dims=[Cell, Edge], dtype=FLOAT_TYPE) + EV_FTYPE = ts.FieldType(dims=[Edge, Vertex], dtype=FLOAT_TYPE) CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) From cb77ccb85c1d5baa83720493c626a66c0c116451 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 18 Oct 2024 11:49:12 +0200 Subject: [PATCH 008/178] refactor[cartesian]: Warn if GT4Py can't find DaCe (#1692) DaCe backends are optional backends in the cartesian version of GT4Py. Currently, we silently drop support for DaCe backends if DaCe can't be imported. This can can happen because DaCe isn't available or for a couple other reasons (e.g. a circular import in the import path). With this PR we thus add a warning message allowing developers to easily figure out that/why DaCe backends are disabled. --- src/gt4py/cartesian/backend/__init__.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index 7a6f877295..e58c7a01a7 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from warnings import warn + from .base import ( REGISTRY, Backend, @@ -16,13 +18,6 @@ from_name, register, ) - - -try: - from .dace_backend import DaceCPUBackend, DaceGPUBackend -except ImportError: - pass - from .cuda_backend import CudaBackend from .gtcpp_backend import GTCpuIfirstBackend, GTCpuKfirstBackend, GTGpuBackend from .module_generator import BaseModuleGenerator @@ -47,5 +42,12 @@ ] -if "DaceCPUBackend" in globals(): +try: + from .dace_backend import DaceCPUBackend, DaceGPUBackend + __all__ += ["DaceCPUBackend", "DaceGPUBackend"] +except ImportError: + warn( + "GT4Py was unable to load DaCe. DaCe backends (`dace:cpu` and `dace:gpu`) will not be available.", + stacklevel=2, + ) From eb0a0c1b322c2a9c32f2beb9e087ee940b4bd19c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 18 Oct 2024 15:21:55 +0200 Subject: [PATCH 009/178] feat[next]: GTIR temporary extraction pass (#1678) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New temporary extraction pass. Transforms an `itir.Program` like ``` testee(inp, out) { out @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp)); } ``` into ``` testee(inp, out) { __tmp_1 = temporary(domain=c⟨ IDimₕ: [0, 1) ⟩, dtype=float64); __tmp_1 @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(inp); out @ c⟨ IDimₕ: [0, 1) ⟩ ← as_fieldop(deref, c⟨ IDimₕ: [0, 1) ⟩)(__tmp_1); } ``` Note that this pass intentionally unconditionally extracts. In case you don't want a temporary you should fuse the `as_fieldop` before. As such the fusion pass (see https://github.com/GridTools/gt4py/pull/1670) contains the heuristics on what to fuse. --- src/gt4py/next/iterator/transforms/cse.py | 13 +- .../iterator/transforms/fencil_to_program.py | 15 +- .../next/iterator/transforms/global_tmps.py | 720 ++++-------------- .../next/iterator/transforms/infer_domain.py | 45 +- .../next/iterator/transforms/pass_manager.py | 52 +- .../next/iterator/type_system/inference.py | 27 +- .../next/program_processors/runners/gtfn.py | 4 +- src/gt4py/next/utils.py | 26 +- .../transforms_tests/test_cse.py | 12 +- .../transforms_tests/test_domain_inference.py | 88 ++- .../transforms_tests/test_global_tmps.py | 553 ++++---------- .../runners_tests/test_gtfn.py | 14 - 12 files changed, 476 insertions(+), 1093 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 1a89adbb20..ccc1d2195f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -32,7 +32,7 @@ @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("type",) + PRESERVED_ANNEX_ATTRS = ("type", "domain") expr_map: dict[int, itir.SymRef] @@ -43,15 +43,16 @@ def visit_Expr(self, node: itir.Node) -> itir.Node: def visit_FunCall(self, node: itir.FunCall) -> itir.Node: node = cast(itir.FunCall, self.visit_Expr(node)) + # TODO(tehrengruber): Use symbol name from the inner let, to increase readability of IR # If we encounter an expression like: # (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr) # (non-recursively) inline the lambda to obtain: # (λ(_cs_1) → _cs_1+_cs_1)(outer_expr) - # This allows identifying more common subexpressions later on + # In the CSE this allows identifying more common subexpressions later on. Other users + # of `extract_subexpression` (e.g. temporary extraction) can also rely on this to avoid + # the need to handle this artificial let-statements. if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): - eligible_params = [] - for arg in node.args: - eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs")) + eligible_params = [isinstance(arg, itir.SymRef) for arg in node.args] if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. @@ -319,7 +320,7 @@ def extract_subexpression( subexprs = CollectSubexpressions.apply(node) # collect multiple occurrences and map them to fresh symbols - expr_map = dict[int, itir.SymRef]() + expr_map: dict[int, itir.SymRef] = {} ignored_ids = set() for expr, subexpr_entry in ( subexprs.items() if not deepest_expr_first else reversed(subexprs.items()) diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index db0b81a837..4ad91645d4 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -9,14 +9,11 @@ from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import global_tmps class FencilToProgram(eve.NodeTranslator): @classmethod - def apply( - cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program - ) -> itir.Program: + def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: @@ -32,13 +29,3 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: body=self.visit(node.closures), implicit_domain=node.implicit_domain, ) - - def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program: - return itir.Program( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.params, - declarations=node.tmps, - body=self.visit(node.fencil.closures), - implicit_domain=node.fencil.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5a6873f916..11d3fccec1 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -8,582 +8,196 @@ from __future__ import annotations -import copy -import dataclasses -from collections.abc import Mapping -from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence - -import gt4py.next as gtx -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.traits import SymbolTableTrait -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.ir_utils.domain_utils import ( - SymbolicDomain, - SymbolicRange, - _max_domain_sizes_by_location_type, - domain_union, -) -from gt4py.next.iterator.pretty_printer import PrettyPrinter -from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.cse import extract_subexpression -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs -from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs -from gt4py.next.iterator.type_system import ( - inference as itir_type_inference, - type_specifications as it_ts, +import functools +from typing import Callable, Optional + +from gt4py.eve import utils as eve_utils +from gt4py.next import common, utils as next_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, ) -from gt4py.next.type_system import type_specifications as ts - - -"""Iterator IR extension for global temporaries. - -Replaces lifted function calls by temporaries using the following steps: -1. Split closures by popping up lifted function calls to the top of the expression tree, (that is, - to stencil arguments) and then extracting them as new closures. -2. Introduces a new fencil-scope variable (the temporary) for each output of newly created closures. - The domain size is set to a new symbol `_gtmp_auto_domain`. -3. Infer the domain sizes for the new closures by analysing the accesses/shifts within all closures - and replace all occurrences of `_gtmp_auto_domain` by concrete domain sizes. -4. Infer the data type and size of the temporary buffers. -""" - - -AUTO_DOMAIN: Final = ir.FunCall(fun=ir.SymRef(id="_gtmp_auto_domain"), args=[]) - - -# Iterator IR extension nodes - - -class FencilWithTemporaries( - ir.Node, SymbolTableTrait -): # TODO(havogt): remove and use new `itir.Program` instead. - """Iterator IR extension: declaration of a fencil with temporary buffers.""" - - fencil: ir.FencilDefinition - params: list[ir.Sym] - tmps: list[ir.Temporary] - - -# Extensions for `PrettyPrinter` for easier debugging - - -def pformat_FencilWithTemporaries( - printer: PrettyPrinter, node: FencilWithTemporaries, *, prec: int -) -> list[str]: - assert prec == 0 - params = printer.visit(node.params, prec=0) - fencil = printer.visit(node.fencil, prec=0) - tmps = printer.visit(node.tmps, prec=0) - args = params + [[tmp.id] for tmp in node.tmps] - - hparams = printer._hmerge([node.fencil.id + "("], *printer._hinterleave(params, ", "), [") {"]) - vparams = printer._vmerge( - [node.fencil.id + "("], *printer._hinterleave(params, ",", indent=True), [") {"] - ) - params = printer._optimum(hparams, vparams) - - hargs = printer._hmerge(*printer._hinterleave(args, ", ")) - vargs = printer._vmerge(*printer._hinterleave(args, ",")) - args = printer._optimum(hargs, vargs) - - fencil = printer._hmerge(fencil, [";"]) - - hcall = printer._hmerge([node.fencil.id + "("], args, [");"]) - vcall = printer._vmerge(printer._hmerge([node.fencil.id + "("]), printer._indent(args), [");"]) - call = printer._optimum(hcall, vcall) - - body = printer._vmerge(*tmps, fencil, call) - return printer._vmerge(params, printer._indent(body), ["}"]) - - -PrettyPrinter.visit_FencilWithTemporaries = pformat_FencilWithTemporaries # type: ignore - - -# Main implementation -def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir.FunCall: - """ - Canonicalize applied lift expressions. - - Transform lift such that the arguments to the applied lift are only symbols. - - >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> it_type = it_ts.IteratorType(position_dims=[], defined_dims=[], element_type=bool_type) - >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")(im.ref("inp", it_type))) - >>> print(expr) - (↑(λ(a) → ·a))((↑deref)(inp)) - >>> print(canonicalize_applied_lift(["inp"], expr)) - (↑(λ(inp) → (λ(a) → ·a)((↑deref)(inp))))(inp) - """ - assert cpm.is_applied_lift(node) - stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied lift - it_args = node.args - if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): - closure_param_refs = collect_symbol_refs(node, as_ref=True) - assert not ({str(ref.id) for ref in closure_param_refs} - set(closure_params)) - new_node = im.lift( - im.lambda_(*[im.sym(param.id) for param in closure_param_refs])( - im.call(stencil)(*it_args) +from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _transform_if( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> Optional[list[itir.Stmt]]: + if isinstance(stmt, itir.SetAt) and cpm.is_call_to(stmt.expr, "if_"): + cond, true_val, false_val = stmt.expr.args + return [ + itir.IfStmt( + cond=cond, + true_branch=_transform_stmt( + itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), + declarations, + uids, + ), + false_branch=_transform_stmt( + itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), + declarations, + uids, + ), ) - )(*closure_param_refs) - # ensure all types are inferred - return itir_type_inference.infer( - new_node, inplace=True, allow_undeclared_symbols=True, offset_provider={} - ) - return node - - -@dataclasses.dataclass(frozen=True) -class TemporaryExtractionPredicate: - """ - Construct a callable that determines if a lift expr can and should be extracted to a temporary. - - The class optionally takes a heuristic that can restrict the extraction. - """ - - heuristics: Optional[Callable[[ir.Expr], bool]] = None - - def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: - """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not cpm.is_applied_lift(expr): - return False - # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) - # as we can not create temporaries for these stencils - assert isinstance(expr.type, it_ts.IteratorType) - if isinstance(expr.type.element_type, it_ts.ListType): - return False - if self.heuristics and not self.heuristics(expr): - return False - stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - # do not extract when the stencil is capturing - used_symbols = collect_symbol_refs(stencil) - if used_symbols: - return False - return True - - -@dataclasses.dataclass(frozen=True) -class SimpleTemporaryExtractionHeuristics: - """ - Heuristic that extracts only if a lift expr is derefed in more than one position. - - Note that such expression result in redundant computations if inlined instead of being - placed into a temporary. - """ - - closure: ir.StencilClosure - - def __post_init__(self) -> None: - trace_shifts.trace_stencil( - self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True - ) - - def __call__(self, expr: ir.Expr) -> bool: - shifts = expr.annex.recorded_shifts - if len(shifts) > 1: - return True - return False - - -def _closure_parameter_argument_mapping(closure: ir.StencilClosure) -> dict[str, ir.Expr]: - """ - Create a mapping from the closures parameters to the closure arguments. - - E.g. for the closure `out ← (λ(param) → ...)(arg) @ u⟨ ... ⟩;` we get a mapping from `param` - to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first - arg is ignored) is returned. - """ - is_scan = cpm.is_call_to(closure.stencil, "scan") - - if is_scan: - stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan - return { - param.id: arg for param, arg in zip(stencil.params[1:], closure.inputs, strict=True) - } - else: - assert isinstance(closure.stencil, ir.Lambda) - return { - param.id: arg for param, arg in zip(closure.stencil.params, closure.inputs, strict=True) - } - - -def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> None: - used_symbol_refs = collect_symbol_refs(expr) - assert not (set(used_symbol_refs) - {param.id for param in whitelist}) - - -def split_closures( - node: ir.FencilDefinition, - offset_provider: common.OffsetProvider, - *, - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, -) -> FencilWithTemporaries: - """Split closures on lifted function calls and introduce new temporary buffers for return values. - - Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the - same name is also added as a fencil argument (to be replaced at a later stage). - - For each closure, follows these steps: - 1. Pops up lifted function calls to the top of the expression tree. - 2. Introduce new temporary for the output. - 3. Extract lifted function class as new closures with the previously created temporary as output. - The closures are processed in reverse order to properly respect the dependencies. - """ - if not extraction_heuristics: - # extract all (eligible) lifts - def always_extract_heuristics(_: ir.StencilClosure) -> Callable[[ir.Expr], bool]: - return lambda _: True - - extraction_heuristics = always_extract_heuristics - - uid_gen_tmps = UIDGenerator(prefix="_tmp") - - node = itir_type_inference.infer(node, offset_provider=offset_provider) - - tmps: list[tuple[str, ts.DataType]] = [] - - closures: list[ir.StencilClosure] = [] - for closure in reversed(node.closures): - closure_stack: list[ir.StencilClosure] = [closure] - while closure_stack: - current_closure: ir.StencilClosure = closure_stack.pop() - - if ( - isinstance(current_closure.stencil, ir.SymRef) - and current_closure.stencil.id == "deref" - ): - closures.append(current_closure) - continue - - is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") - current_closure_stencil = ( - current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan - ) - - extraction_predicate = TemporaryExtractionPredicate( - extraction_heuristics(current_closure) - ) - - stencil_body, extracted_lifts, _ = extract_subexpression( - current_closure_stencil.expr, - extraction_predicate, - uid_gen_tmps, - once_only=True, - deepest_expr_first=True, - ) - - if extracted_lifts: - for tmp_sym, lift_expr in extracted_lifts.items(): - # make sure the applied lift is not capturing anything except of closure params - _ensure_expr_does_not_capture(lift_expr, current_closure_stencil.params) - - assert isinstance(lift_expr, ir.FunCall) and isinstance( - lift_expr.fun, ir.FunCall - ) - - # make sure the arguments to the applied lift are only symbols - if not all(isinstance(arg, ir.SymRef) for arg in lift_expr.args): - lift_expr = canonicalize_applied_lift( - [str(param.id) for param in current_closure_stencil.params], lift_expr - ) - assert all(isinstance(arg, ir.SymRef) for arg in lift_expr.args) - - # create a mapping from the closures parameters to the closure arguments - closure_param_arg_mapping = _closure_parameter_argument_mapping(current_closure) - - # usually an ir.Lambda or scan - stencil: ir.Node = lift_expr.fun.args[0] # type: ignore[attr-defined] # ensured by canonicalize_applied_lift - - # allocate a new temporary - assert isinstance(stencil.type, ts.FunctionType) - assert isinstance(stencil.type.returns, ts.DataType) - tmps.append((tmp_sym.id, stencil.type.returns)) - - # create a new closure that executes the stencil of the applied lift and - # writes the result to the newly created temporary - closure_stack.append( - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=stencil, - output=im.ref(tmp_sym.id), - inputs=[ - closure_param_arg_mapping[param.id] # type: ignore[attr-defined] - for param in lift_expr.args - ], - location=current_closure.location, - ) - ) + ] + return None + + +def _transform_by_pattern( + stmt: itir.Stmt, + predicate: Callable[[itir.Expr, int], bool], + declarations: list[itir.Temporary], + uids: eve_utils.UIDGenerator, +) -> Optional[list[itir.Stmt]]: + if not isinstance(stmt, itir.SetAt): + return None + + new_expr, extracted_fields, _ = cse.extract_subexpression( + stmt.expr, + predicate=predicate, + uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"), + # TODO(tehrengruber): extracting the deepest expression first would allow us to fuse + # the extracted expressions resulting in fewer kernel calls & better data-locality. + # Extracting multiple expressions deepest-first is however not supported right now. + # deepest_expr_first=True # noqa: ERA001 + ) - new_stencil: ir.Lambda | ir.FunCall - # create a new stencil where all applied lifts that have been extracted are - # replaced by references to the respective temporary - new_stencil = ir.Lambda( - params=current_closure_stencil.params + list(extracted_lifts.keys()), - expr=stencil_body, - ) - # if we are extracting from an applied scan we have to wrap the scan pass again, - # i.e. transform `λ(state, ...) → ...` into `scan(λ(state, ...) → ..., ...)` - if is_scan: - new_stencil = im.call("scan")(new_stencil, current_closure.stencil.args[1:]) # type: ignore[attr-defined] # ensure by is_scan - # inline such that let statements which are just rebinding temporaries disappear - new_stencil = InlineLambdas.apply( - new_stencil, opcount_preserving=True, force_inline_lift_args=False + if extracted_fields: + tmp_stmts: list[itir.Stmt] = [] + + # for each extracted expression generate: + # - one or more `Temporary` declarations (depending on whether the expression is a field + # or a tuple thereof) + # - one `SetAt` statement that materializes the expression into the temporary + for tmp_sym, tmp_expr in extracted_fields.items(): + domain = tmp_expr.annex.domain + + # TODO(tehrengruber): Implement. This happens when the expression is a combination + # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are + # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level + # of a SetAt, the CollapseTuple pass will eliminate most of this cases. + if isinstance(domain, tuple): + flattened_domains: tuple[domain_utils.SymbolicDomain] = ( + next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough ) - # we're done with the current closure, add it back to the stack for further - # extraction. - closure_stack.append( - ir.StencilClosure( - domain=current_closure.domain, - stencil=new_stencil, - output=current_closure.output, - inputs=current_closure.inputs - + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], - location=current_closure.location, + if not all(d == flattened_domains[0] for d in flattened_domains): + raise NotImplementedError( + "Tuple expressions with different domains is not supported yet." ) + domain = flattened_domains[0] + assert isinstance(domain, domain_utils.SymbolicDomain) + domain_expr = domain.as_expr() + + assert isinstance(tmp_expr.type, ts.TypeSpec) + tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( + lambda x: uids.sequential_id(), + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) + tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( + type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), ) - else: - closures.append(current_closure) - - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params + [im.sym(name) for name, _ in tmps] + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant - closures=list(reversed(closures)), - location=node.location, - implicit_domain=node.implicit_domain, - ), - params=node.params, - tmps=[ir.Temporary(id=name, dtype=type_) for name, type_ in tmps], - ) - - -def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporaries: - """Remove temporaries that are never read.""" - unused_tmps = {tmp.id for tmp in node.tmps} - for closure in node.fencil.closures: - unused_tmps -= {inp.id for inp in closure.inputs} - - if not unused_tmps: - return node - - closures = [ - closure - for closure in node.fencil.closures - if not (isinstance(closure.output, ir.SymRef) and closure.output.id in unused_tmps) - ] - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=[p for p in node.fencil.params if p.id not in unused_tmps], - closures=closures, - location=node.fencil.location, - ), - params=node.params, - tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], - ) - - -def _group_offsets( - offset_literals: Sequence[ir.OffsetLiteral], -) -> Sequence[tuple[str, int | Literal[trace_shifts.Sentinel.ALL_NEIGHBORS]]]: - tags = [tag.value for tag in offset_literals[::2]] - offsets = [ - offset.value if isinstance(offset, ir.OffsetLiteral) else offset - for offset in offset_literals[1::2] - ] - assert all(isinstance(tag, str) for tag in tags) - assert all( - isinstance(offset, int) or offset == trace_shifts.Sentinel.ALL_NEIGHBORS - for offset in offsets - ) - return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly - + ) -def update_domains( - node: FencilWithTemporaries, - offset_provider: Mapping[str, Any], - symbolic_sizes: Optional[dict[str, str]], -) -> FencilWithTemporaries: - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] - domains = dict[str, ir.FunCall]() - for closure in reversed(node.fencil.closures): - if closure.domain == AUTO_DOMAIN: - # every closure with auto domain should have a single out field - assert isinstance(closure.output, ir.SymRef) + # allocate temporary for all tuple elements + def allocate_temporary(tmp_name: str, dtype: ts.ScalarType): + declarations.append(itir.Temporary(id=tmp_name, domain=domain_expr, dtype=dtype)) # noqa: B023 # function only used inside loop - if closure.output.id not in domains: - raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") + next_utils.tree_map(allocate_temporary)(tmp_names, tmp_dtypes) - domain = domains[closure.output.id] + # if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a + # `make_tuple` expression. + target_expr: itir.Expr = next_utils.tree_map( + lambda x: im.ref(x), result_collection_constructor=lambda els: im.make_tuple(*els) + )(tmp_names) # type: ignore[assignment] # typing of tree_map does not reflect action of `result_collection_constructor` yet - closure = ir.StencilClosure( - domain=copy.deepcopy(domain), - stencil=closure.stencil, - output=closure.output, - inputs=closure.inputs, - location=closure.location, + # note: the let would be removed automatically by the `cse.extract_subexpression`, but + # we remove it here for readability & debuggability. + new_expr = inline_lambdas.inline_lambda( + im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False ) - else: - domain = closure.domain - closures.append(closure) - - local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs)) - for param_sym, shift_chains in zip(closure.inputs, local_shifts): - param = param_sym.id - assert isinstance(param, str) - consumed_domains: list[SymbolicDomain] = ( - [SymbolicDomain.from_expr(domains[param])] if param in domains else [] + # TODO(tehrengruber): _transform_stmt not needed if deepest_expr_first=True + tmp_stmts.extend( + _transform_stmt( + itir.SetAt(target=target_expr, domain=domain_expr, expr=tmp_expr), + declarations, + uids, + ) ) - for shift_chain in shift_chains: - consumed_domain = SymbolicDomain.from_expr(domain) - for offset_name, offset in _group_offsets(shift_chain): - if isinstance(offset_provider[offset_name], gtx.Dimension): - # cartesian shift - dim = offset_provider[offset_name] - assert offset is not trace_shifts.Sentinel.ALL_NEIGHBORS - consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], common.Connectivity): - # unstructured shift - nbt_provider = offset_provider[offset_name] - old_axis = nbt_provider.origin_axis - new_axis = nbt_provider.neighbor_axis - assert new_axis not in consumed_domain.ranges or old_axis == new_axis + return [*tmp_stmts, itir.SetAt(target=stmt.target, domain=stmt.domain, expr=new_expr)] + return None - if symbolic_sizes is None: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal( - str(horizontal_sizes[new_axis.value]), ir.INTEGER_INDEX_BUILTIN - ), - ) - else: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(symbolic_sizes[new_axis.value]), - ) - consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. - consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - else: - raise NotImplementedError() - consumed_domains.append(consumed_domain) - # compute the bounds of all consumed domains - if consumed_domains: - if all( - consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() - for consumed_domain in consumed_domains - ): # scalar otherwise - domains[param] = domain_union(*consumed_domains).as_expr() +def _transform_stmt( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> list[itir.Stmt]: + unprocessed_stmts: list[itir.Stmt] = [stmt] + stmts: list[itir.Stmt] = [] - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again - closures=list(reversed(closures)), - location=node.fencil.location, - implicit_domain=node.fencil.implicit_domain, + transforms: list[Callable] = [ + # transform `if_` call into `IfStmt` + _transform_if, + # extract applied `as_fieldop` to top-level + functools.partial( + _transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) ), - params=node.params, - tmps=node.tmps, - ) - - -def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if cpm.is_call_to(node, "make_tuple"): - for arg in node.args: - yield from _tuple_constituents(arg) - else: - yield node - - -def collect_tmps_info( - node: FencilWithTemporaries, *, offset_provider: common.OffsetProvider -) -> FencilWithTemporaries: - """Perform type inference for finding the types of temporaries and sets the temporary size.""" - tmps = {tmp.id for tmp in node.tmps} - domains: dict[str, ir.Expr] = {} - for closure in node.fencil.closures: - for output_field in _tuple_constituents(closure.output): - assert isinstance(output_field, ir.SymRef) - if output_field.id not in tmps: - continue - - assert output_field.id not in domains or domains[output_field.id] == closure.domain - domains[output_field.id] = closure.domain + # extract if_ call to the top-level + functools.partial( + _transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") + ), + ] - new_node = FencilWithTemporaries( - fencil=node.fencil, - params=node.params, - tmps=[ - ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps - ], - ) - # TODO(tehrengruber): type inference is only really needed to infer the types of the temporaries - # and write them to the params of the inner fencil. This should be cleaned up after we - # refactored the IR. - return itir_type_inference.infer(new_node, offset_provider=offset_provider) + while unprocessed_stmts: + stmt = unprocessed_stmts.pop(0) + did_transform = False + for transform in transforms: + transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids) + if transformed_stmts: + unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts] + did_transform = True + break -def validate_no_dynamic_offsets(node: ir.Node) -> None: - """Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)`""" - for call_node in node.walk_values().if_isinstance(ir.FunCall): - assert isinstance(call_node, ir.FunCall) - if cpm.is_call_to(call_node, "shift"): - if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): - raise NotImplementedError("Dynamic offsets not supported in temporary pass.") + # no transformation occurred + if not did_transform: + stmts.append(stmt) + return stmts -# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be -# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore -# and hence also not extract as a temporary. -class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): - """Main entry point for introducing global temporaries. - Transforms an existing iterator IR fencil into a fencil with global temporaries. +def create_global_tmps( + program: itir.Program, offset_provider: common.OffsetProvider +) -> itir.Program: """ + Given an `itir.Program` create temporaries for intermediate values. - def visit_FencilDefinition( - self, - node: ir.FencilDefinition, - *, - offset_provider: Mapping[str, Any], - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, - symbolic_sizes: Optional[dict[str, str]], - ) -> FencilWithTemporaries: - # Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)` - validate_no_dynamic_offsets(node) - # Split closures on lifted function calls and introduce temporaries - res = split_closures( - node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics - ) - # Prune unreferences closure inputs introduced in the previous step - res = PruneClosureInputs().visit(res) - # Prune unused temporaries possibly introduced in the previous step - res = prune_unused_temporaries(res) - # Perform an eta-reduction which should put all calls at the highest level of a closure - res = EtaReduction().visit(res) - # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider, symbolic_sizes) - # Use type inference to determine the data type of the temporaries - return collect_tmps_info(res, offset_provider=offset_provider) + This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its + arguments into temporaries. + """ + program = infer_domain.infer_program(program, offset_provider) + program = type_inference.infer(program, offset_provider=offset_provider) + + uids = eve_utils.UIDGenerator(prefix="__tmp") + declarations = program.declarations.copy() + new_body = [] + + for stmt in program.body: + assert isinstance(stmt, itir.SetAt) + new_body.extend(_transform_stmt(stmt, uids=uids, declarations=declarations)) + + return itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=declarations, + body=new_body, + ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index c1a743af1c..2a85e6f2cf 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,7 +21,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.utils import tree_map +from gt4py.next.utils import flatten_nested_tuple, tree_map DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] @@ -134,6 +134,9 @@ def infer_as_fieldop( assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if target_domain is None: raise ValueError("'target_domain' cannot be 'None'.") + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. + if isinstance(target_domain, tuple): + target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) if not isinstance(target_domain, domain_utils.SymbolicDomain): raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") @@ -157,23 +160,23 @@ def infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s + accessed_domains: ACCESSED_DOMAINS = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, accessed_domains[in_field_id], offset_provider + in_field, inputs_accessed_domains[in_field_id], offset_provider ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - transformed_call = im.as_fieldop(stencil, domain_utils.SymbolicDomain.as_expr(target_domain))( - *transformed_inputs - ) + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) accessed_domains_without_tmp = { k: v @@ -245,7 +248,8 @@ def infer_make_tuple( infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(*infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(*infered_args_expr) + return result_expr, actual_domains def infer_tuple_get( @@ -255,12 +259,13 @@ def infer_tuple_get( ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} - idx, tuple_arg = expr.args - assert isinstance(idx, itir.Literal) - child_domain = tuple(None if i != int(idx.value) else domain for i in range(int(idx.value) + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, child_domain, offset_provider) + idx_expr, tuple_arg = expr.args + assert isinstance(idx_expr, itir.Literal) + idx = int(idx_expr.value) + tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, offset_provider) - infered_args_expr = im.tuple_get(idx.value, infered_arg_expr) + infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) return infered_args_expr, actual_domains @@ -278,10 +283,11 @@ def infer_if( infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(cond, *infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains -def infer_expr( +def _infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, @@ -310,6 +316,17 @@ def infer_expr( raise ValueError(f"Unsupported expression: {expr}") +def infer_expr( + expr: itir.Expr, + domain: DOMAIN, + offset_provider: common.OffsetProvider, +) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + # this is just a small wrapper that populates the `domain` annex + expr, accessed_domains = _infer_expr(expr, domain, offset_provider) + expr.annex.domain = domain + return expr, accessed_domains + + def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8dd76b289b..b3bb7bc6e1 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -18,7 +18,6 @@ from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps, FencilWithTemporaries from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -74,12 +73,14 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - if isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)): + if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram().apply( ir ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR @@ -137,29 +138,30 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: # FIXME[#1582](tehrengruber): implement new temporary pass here raise NotImplementedError() - assert offset_provider is not None - ir = CreateGlobalTmps().visit( - ir, - offset_provider=offset_provider, - extraction_heuristics=temporary_extraction_heuristics, - symbolic_sizes=symbolic_domain_sizes, - ) - - for _ in range(10): - inlined = InlineLifts().visit(ir) - inlined = InlineLambdas.apply( - inlined, opcount_preserving=True, force_inline_lift_args=True - ) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - - # If after creating temporaries, the scan is not at the top, we inline. - # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. - # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - ir = _inline_into_scan(ir) + # ruff: noqa: ERA001 + # assert offset_provider is not None + # ir = CreateGlobalTmps().visit( + # ir, + # offset_provider=offset_provider, + # extraction_heuristics=temporary_extraction_heuristics, + # symbolic_sizes=symbolic_domain_sizes, + # ) + # + # for _ in range(10): + # inlined = InlineLifts().visit(ir) + # inlined = InlineLambdas.apply( + # inlined, opcount_preserving=True, force_inline_lift_args=True + # ) + # if inlined == ir: + # break + # ir = inlined + # else: + # raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + # + # # If after creating temporaries, the scan is not at the top, we inline. + # # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. + # # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` + # ir = _inline_into_scan(ir) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index a13c7fb816..4640aa11d1 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -19,7 +19,6 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.type_system.type_info import primitive_constituents @@ -292,6 +291,8 @@ def type_synthesizer(*args, **kwargs): class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + PRESERVED_ANNEX_ATTRS = ("domain",) + def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node: node = self.generic_visit(node) # We only want to sanitize types that have been inferred previously such that we don't run @@ -315,6 +316,8 @@ class ITIRTypeInference(eve.NodeTranslator): See :method:ITIRTypeInference.apply for more details. """ + PRESERVED_ANNEX_ATTRS = ("domain",) + offset_provider: common.OffsetProvider #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] @@ -466,28 +469,6 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.F closures = self.visit(node.closures, ctx=ctx | params | function_definitions) return it_ts.FencilType(params=params, closures=closures) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilWithTemporaries( - self, node: global_tmps.FencilWithTemporaries, *, ctx - ) -> it_ts.FencilType: - # TODO(tehrengruber): This implementation is not very appealing. Since we are about to - # refactor the IR anyway this is fine for now. - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - # infer types of temporary declarations - tmps: dict[str, ts.FieldType] = {} - for tmp_node in node.tmps: - tmps[tmp_node.id] = self.visit(tmp_node, ctx=ctx | params) - # and store them in the inner fencil - for fencil_param in node.fencil.params: - if fencil_param.id in tmps: - fencil_param.type = tmps[fencil_param.id] - self.visit(node.fencil, ctx=ctx) - assert isinstance(node.fencil.type, it_ts.FencilType) - return node.fencil.type - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 35db4cb7f2..2275576081 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -18,7 +18,6 @@ from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config from gt4py.next.iterator import transforms -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -168,8 +167,9 @@ class Params: name_cached="_cached", ) use_temporaries = factory.Trait( + # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, + # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001 name_temps="_with_temporaries", ) device_type = core_defs.DeviceType.CPU diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 44fa929e56..7489908ba9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -68,7 +68,12 @@ def flatten_nested_tuple( @overload -def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... +def tree_map( + fun: Callable[_P, _R], + *, + collection_type: type | tuple[type, ...] = tuple, + result_collection_constructor: Optional[type | Callable] = None, +) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... @overload @@ -82,7 +87,8 @@ def tree_map( def tree_map( - *args: Callable[_P, _R], + fun: Optional[Callable[_P, _R]] = None, + *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[type | Callable] = None, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: @@ -108,6 +114,12 @@ def tree_map( ... [[1, 2], 3] ... ) ((2, 3), 4) + + >>> @tree_map + ... def impl(x): + ... return x + 1 + >>> impl(((1, 2), 3)) + ((2, 3), 4) """ if result_collection_constructor is None: @@ -117,8 +129,7 @@ def tree_map( ) result_collection_constructor = collection_type - if len(args) == 1: - fun = args[0] + if fun: @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: @@ -129,17 +140,14 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert result_collection_constructor is not None return result_collection_constructor(impl(*arg) for arg in zip(*args)) - return fun( + return fun( # type: ignore[misc] # mypy not smart enough *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - if len(args) == 0: + else: return functools.partial( tree_map, collection_type=collection_type, result_collection_constructor=result_collection_constructor, ) - raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 78f95da8ca..3204b49371 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -74,16 +74,10 @@ def common_expr(): return im.plus("x", "x") # λ(x) → (λ(y) → y + (x + x + (x + x)))(z) - testee = im.lambda_("x")( - im.call(im.lambda_("y")(im.plus("y", im.plus(common_expr(), common_expr()))))("z") - ) - # λ(x) → (λ(_cs_1) → (λ(y) → y + (_cs_1 + _cs_1))(z))(x + x) + testee = im.lambda_("x")(im.let("y", "z")(im.plus("y", im.plus(common_expr(), common_expr())))) + # λ(x) → (λ(_cs_1) → z + (_cs_1 + _cs_1))(x + x) expected = im.lambda_("x")( - im.call( - im.lambda_("_cs_1")( - im.call(im.lambda_("y")(im.plus("y", im.plus("_cs_1", "_cs_1"))))("z") - ) - )(common_expr()) + im.let("_cs_1", common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) actual = CSE.apply(testee, is_local_view=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 79456e4d85..50756f40e7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -16,7 +16,7 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain +from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension from gt4py.next import common, NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -86,7 +86,7 @@ def run_test_expr( offset_provider: common.OffsetProvider, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -122,7 +122,7 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( domains: infer_domain.ACCESSED_DOMAINS, ) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: SymbolicDomain | None): + def fold_domain(domain: domain_utils.SymbolicDomain | None): if domain is None: return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -134,7 +134,7 @@ def translate_domain( domain: itir.FunCall, shifts: dict[str, tuple[itir.Expr, itir.Expr]], offset_provider: common.OffsetProvider, -) -> SymbolicDomain: +) -> domain_utils.SymbolicDomain: shift_tuples = [ ( im.ensure_offset(d), @@ -145,7 +145,9 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] - translated_domain_expr = SymbolicDomain.from_expr(domain).translate(shift_list, offset_provider) + translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( + shift_list, offset_provider + ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -330,7 +332,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -374,7 +376,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -512,7 +514,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -569,7 +571,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -789,7 +791,10 @@ def test_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -808,7 +813,7 @@ def test_tuple_get_1_make_tuple(offset_provider): } actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -824,7 +829,10 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -840,7 +848,9 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), + domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + ), offset_provider, ) @@ -856,7 +866,7 @@ def test_tuple_get_let_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(domain), + domain_utils.SymbolicDomain.from_expr(domain), offset_provider, ) @@ -877,10 +887,13 @@ def test_nested_make_tuple(offset_provider): testee, ( ( - SymbolicDomain.from_expr(domain1), - (SymbolicDomain.from_expr(domain2_1), SymbolicDomain.from_expr(domain2_2)), + domain_utils.SymbolicDomain.from_expr(domain1), + ( + domain_utils.SymbolicDomain.from_expr(domain2_1), + domain_utils.SymbolicDomain.from_expr(domain2_2), + ), ), - SymbolicDomain.from_expr(domain3), + domain_utils.SymbolicDomain.from_expr(domain3), ), offset_provider, ) @@ -896,7 +909,7 @@ def test_tuple_get_1(offset_provider): expected_domains = {"a": (None, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -912,7 +925,10 @@ def test_domain_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -929,7 +945,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -945,7 +961,10 @@ def test_make_tuple_2tuple_get(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), offset_provider, ) @@ -963,7 +982,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) assert expected == actual @@ -971,15 +990,34 @@ def test_make_tuple_non_tuple_domain(offset_provider): def test_arithmetic_builtin(offset_provider): - testee = im.plus(im.ref("in_field1"), im.ref("in_field2")) + testee = im.plus(im.ref("alpha"), im.ref("beta")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected = im.plus(im.ref("in_field1"), im.ref("in_field2")) + expected = im.plus(im.ref("alpha"), im.ref("beta")) expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) assert folded_call == expected assert actual_domains == expected_domains + + +def test_scan(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + testee = im.as_fieldop( + im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) + )("a") + expected = im.as_fieldop( + im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), + domain, + )("a") + + run_test_expr( + testee, + expected, + domain, + {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, + offset_provider, + ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index ffb5447684..23f62842c4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -6,464 +6,219 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# TODO(tehrengruber): add integration tests for temporaries starting from manually written -# itir. Currently we only test temporaries from frontend code which makes testing changes -# to anything related to temporaries tedious. -import copy +from typing import Optional -import gt4py.next as gtx -from gt4py.eve.utils import UIDs from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.global_tmps import ( - AUTO_DOMAIN, - FencilWithTemporaries, - SimpleTemporaryExtractionHeuristics, - collect_tmps_info, - split_closures, - update_domains, -) +from gt4py.next.iterator.transforms import global_tmps, infer_domain +from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts IDim = common.Dimension(value="IDim") JDim = common.Dimension(value="JDim") KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, ir.INTEGER_INDEX_BUILTIN.upper())) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) -def test_split_closures(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("baz_inp")( - im.deref( - im.lift( - im.lambda_("bar_inp")( - im.deref( - im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") - ) - ) - )("baz_inp") - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], - ) - ], + params=params, + declarations=declarations or [], + body=body, ) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_tmp_2", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_tmp_2"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp"), im.ref("_tmp_2")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], - ), - ], - ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ - ir.Temporary(id="_tmp_1", dtype=float_type), - ir.Temporary(id="_tmp_2", dtype=float_type), - ] - assert actual.fencil == expected - -def test_split_closures_simple_heuristics(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("foo")( - im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( - im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], +def test_trivial(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)(im.as_fieldop("deref", domain)("inp")), + domain=domain, ) ], ) + testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar")(im.deref("bar")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("foo", "_tmp_1")( - im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain ), ], ) - actual = split_closures( - testee, - extraction_heuristics=SimpleTemporaryExtractionHeuristics, - offset_provider={"I": IDim}, - ) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee, offset_provider) + assert actual == expected -def test_split_closures_lifted_scan(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - function_definitions=[], +def test_trivial_let(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("a")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - im.literal_from_value(0.0), - ) - )( - im.lift( - im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ) - )("a") - ) + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.let("tmp", im.as_fieldop("deref", domain)("inp"))( + im.as_fieldop("deref", domain)("tmp") ), - output=im.ref("out"), - inputs=[im.ref("inp")], + domain=domain, ) ], ) + testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=im.lambda_("a", "_tmp_1")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - im.literal_from_value(0.0), - ) - )("_tmp_1") - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain ), ], ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee, offset_provider) + assert actual == expected -def test_update_cartesian_domains(): - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("i", index_type), - im.sym("j", index_type), - im.sym("k", index_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_gtmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.ref("deref"), - output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")( - *( - im.call("named_range")( - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ) - ), - stencil=im.lambda_("baz_inp", "_lift_2")(im.deref(im.shift("I", 1)("_lift_2"))), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], +def test_top_level_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", domain)("inp2"), ), - ], - ), - params=[im.sym("i"), im.sym("j"), im.sym("k"), im.sym("inp"), im.sym("out")], - tmps=[ir.Temporary(id="_gtmp_0"), ir.Temporary(id="_gtmp_1")], - ) - expected = copy.deepcopy(testee) - assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain") - expected.fencil.closures[0].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), - ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], + domain=domain, ) - for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - expected.fencil.closures[1].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), + testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) + + expected = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + declarations=[], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain, + ) ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), + false_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain, + ) ], ) - for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) + + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected -def test_collect_tmps_info(): - tmp_domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=im.ref("plus"), - args=[im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)], - ), - ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], +def test_nested_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)( + im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", domain)("inp2"), + ) + ), + domain=domain, ) - for a, s in (("JDim", "j"), ("KDim", "k")) ], ) + testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - i = im.sym("i", index_type) - j = im.sym("j", index_type) - k = im.sym("k", index_type) - inp = im.sym("inp", i_field_type) - out = im.sym("out", i_field_type) - - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - i, - j, - k, - inp, - out, - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=tmp_domain, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall(fun=im.ref("deref"), args=[im.ref("foo_inp")]), - ), - output=im.ref("_gtmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=tmp_domain, - stencil=im.ref("deref"), - output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ], - ), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=im.ref("deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=im.ref("shift"), - args=[ - ir.OffsetLiteral(value="I"), - ir.OffsetLiteral(value=1), - ], - ), - args=[im.ref("_lift_2")], - ) - ], - ), - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], - ), - ], - ), - params=[i, j, k, inp, out], - tmps=[ - ir.Temporary(id="_gtmp_0", dtype=float_type), - ir.Temporary(id="_gtmp_1", dtype=float_type), + expected = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), ], - ) - expected = FencilWithTemporaries( - fencil=testee.fencil, - params=testee.params, - tmps=[ - ir.Temporary(id="_gtmp_0", domain=tmp_domain, dtype=float_type), - ir.Temporary(id="_gtmp_1", domain=tmp_domain, dtype=float_type), + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain, + ) + ], + false_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain, + ) + ], + ), + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain + ), ], ) - actual = collect_tmps_info(testee, offset_provider={"I": IDim, "J": JDim, "K": KDim}) + + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index ab86dda16b..3d82dd8ee5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -50,20 +50,6 @@ def test_backend_factory_trait_cached(): assert cached_version.name == "run_gtfn_cpu_cached" -def test_backend_factory_trait_temporaries(): - inline_version = gtfn.GTFNBackendFactory(cached=False) - temps_version = gtfn.GTFNBackendFactory(cached=False, use_temporaries=True) - - assert inline_version.executor.translation.lift_mode is None - assert temps_version.executor.translation.lift_mode is transforms.LiftMode.USE_TEMPORARIES - - assert inline_version.executor.translation.temporary_extraction_heuristics is None - assert ( - temps_version.executor.translation.temporary_extraction_heuristics - is global_tmps.SimpleTemporaryExtractionHeuristics - ) - - def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) session_version = gtfn.GTFNBackendFactory() From 7f7a866885756f3b5febfd1b11f726595aafcf41 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 23 Oct 2024 11:26:58 +0200 Subject: [PATCH 010/178] feat[next][dace]: Cleanup scalar args (#1695) This PR simplifies the representation of program scalar arguments in the SDFG: instead of promoting them to symbols, we represent them as scalar data. Scalar to symbol promotion is a transformation pass available in DaCe that should be applied after lowering to SDFG, eventually. Additional changes: - Fixes a problem in naming of input connectors for scalar expression tasklets: the connector name cannot match the argument name, otherwise we cannot pass the same value to two connectors in expressions like `out = tmp * tmp`. - Only propagate to lambda scope the symbols that are referenced. --- .../gtir_builtin_translators.py | 11 +- .../runners/dace_fieldview/gtir_dataflow.py | 2 +- .../runners/dace_fieldview/gtir_sdfg.py | 163 +++++++++++++----- .../dace_tests/test_gtir_to_sdfg.py | 49 +++++- 4 files changed, 164 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index a8ae1cc0e8..3cd2b17b88 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -481,13 +481,12 @@ def _get_data_nodes( sym_node = state.add_access(sym_name) return Field(sym_node, sym_type) elif isinstance(sym_type, ts.ScalarType): - if sym_name in sdfg.arrays: - # access the existing scalar container - sym_node = state.add_access(sym_name) - else: + if sym_name in sdfg.symbols: sym_node = _get_symbolic_value( sdfg, state, sdfg_builder, sym_name, sym_type, temp_name=f"__{sym_name}" ) + else: + sym_node = state.add_access(sym_name) return Field(sym_node, sym_type) elif isinstance(sym_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(sym_name, sym_type) @@ -612,7 +611,7 @@ def translate_scalar_expr( connectors = [] scalar_expr_args = [] - for arg_expr in node.args: + for i, arg_expr in enumerate(node.args): visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: @@ -636,7 +635,7 @@ def translate_scalar_expr( ) if not (isinstance(arg, Field) and isinstance(arg.data_type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") - param = f"__in_{arg.data_node.data}" + param = f"__arg{i}" args.append(arg.data_node) connectors.append(param) scalar_expr_args.append(gtir.SymRef(id=param)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 0e571fc17d..fdd92b57c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -939,7 +939,7 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: arg_expr = self.visit(arg) if isinstance(arg_expr, MemletExpr | DataExpr): # the argument value is the result of a tasklet node or direct field access - connector = f"__inp_{i}" + connector = f"__arg{i}" node_connections[connector] = arg_expr node_internals.append(connector) else: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 3697609d76..31e561a19c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -17,6 +17,7 @@ import abc import dataclasses import itertools +import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace @@ -26,7 +27,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts +from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -106,6 +107,38 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ... +def _collect_symbols_in_domain_expressions( + ir: gtir.Node, ir_params: Sequence[gtir.Sym] +) -> set[str]: + """ + Collect symbols accessed in domain expressions that also appear in the paremeter list. + + This function is used to identify all parameters that are accessed in domain + expressions. They have to be passed to the SDFG call as DaCe symbols (instead + of scalars) such that they can be used as bounds in map ranges. + + Args: + ir: GTIR node to be traversed and where to search for domain expressions. + ir_params: List of parameters to search for in domain expressions. + + Returns: + A set of names corresponding to the parameters found in domain expressions. + """ + params = {str(sym.id) for sym in ir_params} + return set( + eve.walk_values(ir) + .filter(lambda node: cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))) + .map( + lambda domain: eve.walk_values(domain) + .if_isinstance(gtir.SymRef) + .map(lambda symref: str(symref.id)) + .filter(lambda sym: sym in params) + .to_list() + ) + .reduce(operator.add, init=[]) + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -172,60 +205,79 @@ def _make_array_shape_and_strides( def _add_storage( self, sdfg: dace.SDFG, + symbolic_arguments: set[str], name: str, - symbol_type: ts.DataType, + gt_type: ts.DataType, transient: bool = True, - is_tuple_member: bool = False, ) -> list[tuple[str, ts.DataType]]: """ - Add storage for data containers used in the SDFG. For fields, it allocates dace arrays, - while scalars are stored as SDFG symbols. + Add storage in the SDFG for a given GT4Py data symbol. - The fields used as temporary arrays, when `transient = True`, are allocated and exist - only within the SDFG; when `transient = False`, the fields have to be allocated outside - and have to be passed as array arguments to the SDFG. + GT4Py fields are allocated as DaCe arrays. GT4Py scalars are represented + as DaCe scalar objects in the SDFG; the exception are the symbols passed as + `symbolic_arguments`, e.g. symbols used in domain expressions, and those used + for symbolic array shape and strides. + + The fields used as temporary arrays, when `transient = True`, are allocated + and exist only within the SDFG; when `transient = False`, the fields have + to be allocated outside and have to be passed as arguments to the SDFG call. + + Args: + sdfg: The SDFG where storage needs to be allocated. + symbolic_arguments: Set of GT4Py scalars that must be represented as SDFG symbols. + name: Symbol Name to be allocated. + gt_type: GT4Py symbol type. + transient: True when the data symbol has to be allocated as internal storage. Returns: - List of data containers or symbols allocated as storage. This is a list, not a single value, - because in case of tuples we flat the tuple fields (eventually nested) and allocate storage - for each tuple element. + List of tuples '(data_name, gt_type)' where 'data_name' is the name of + the data container used as storage in the SDFG and 'gt_type' is the + corresponding GT4Py type. In case the storage has to be allocated for + a tuple symbol the list contains a flattened version of the tuple, + otherwise the list will contain a single entry. """ - if isinstance(symbol_type, ts.TupleType): + if isinstance(gt_type, ts.TupleType): tuple_fields = [] for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, symbol_type, flatten=True + name, gt_type, flatten=True ): tuple_fields.extend( - self._add_storage(sdfg, tname, tsymbol_type, transient, is_tuple_member=True) + self._add_storage(sdfg, symbolic_arguments, tname, tsymbol_type, transient) ) return tuple_fields - elif isinstance(symbol_type, ts.FieldType): - dtype = dace_utils.as_dace_type(symbol_type.dtype) + elif isinstance(gt_type, ts.FieldType): + dtype = dace_utils.as_dace_type(gt_type.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, symbol_type.dims) + sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=transient) - return [(name, symbol_type)] + return [(name, gt_type)] - elif isinstance(symbol_type, ts.ScalarType): - dtype = dace_utils.as_dace_type(symbol_type) - # Scalar arguments passed to the program are represented as symbols in DaCe SDFG; - # the exception are members of tuple arguments, that are represented as scalar containers. - # The field size is sometimes passed as scalar argument to the program, so we have to - # check if the shape symbol was already allocated by `_make_array_shape_and_strides`. - # We assume that the scalar argument for field size always follows the field argument. - if is_tuple_member: - sdfg.add_scalar(name, dtype, transient=transient) - elif name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: + elif isinstance(gt_type, ts.ScalarType): + dtype = dace_utils.as_dace_type(gt_type) + if name in symbolic_arguments: sdfg.add_symbol(name, dtype) + elif dace_utils.is_field_symbol(name): + # Sometimes, when the field domain is implicitly derived from the + # field domain, the gt4py lowering adds the field size as a scalar + # argument to the program IR. Suppose a field '__sym', then gt4py + # will add '__sym_size_0'. + # Therefore, here we check whether the shape symbol was already + # created by `_make_array_shape_and_strides`, when allocating + # storage for field arguments. We assume that the scalar argument + # for field size, if present, always follows the field argument. + if name in sdfg.symbols: + assert sdfg.symbols[name].dtype == dtype + else: + sdfg.add_symbol(name, dtype) + else: + sdfg.add_scalar(name, dtype, transient=transient) - return [(name, symbol_type)] + return [(name, gt_type)] - raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.") + raise RuntimeError(f"Data type '{type(gt_type)}' not supported.") def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: """ @@ -246,10 +298,6 @@ def _visit_expression( Returns: A list of array nodes containing the result fields. - - TODO: Do we need to return the GT4Py `FieldType`/`ScalarType`? It is needed - in case the transient arrays containing the expression result are not guaranteed - to have the same memory layout as the target array. """ result = self.visit(node, sdfg=sdfg, head_state=head_state, reduce_identity=None) @@ -274,14 +322,28 @@ def make_temps(field: gtir_builtin_translators.Field) -> gtir_builtin_translator temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) - def _add_sdfg_params(self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym]) -> list[str]: - """Helper function to add storage for node parameters and connectivity tables.""" + def _add_sdfg_params( + self, + sdfg: dace.SDFG, + node_params: Sequence[gtir.Sym], + symbolic_arguments: set[str], + ) -> list[str]: + """ + Helper function to add storage for node parameters and connectivity tables. + + GT4Py field arguments will be translated to `dace.data.Array` objects. + GT4Py scalar arguments will be translated to `dace.data.Scalar` objects, + except when they are listed in 'symbolic_arguments', in which case they + will be represented in the SDFG as DaCe symbols. + """ # add non-transient arrays and/or SDFG symbols for the program arguments sdfg_args = [] for param in node_params: pname = str(param.id) assert isinstance(param.type, (ts.DataType)) - sdfg_args += self._add_storage(sdfg, pname, param.type, transient=False) + sdfg_args += self._add_storage( + sdfg, symbolic_arguments, pname, param.type, transient=False + ) self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables @@ -290,7 +352,7 @@ def _add_sdfg_params(self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym]) -> ).items(): scalar_kind = tt.get_scalar_kind(offset_provider.index_type) local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) - type_ = ts.FieldType( + gt_type = ts.FieldType( [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) ) # We store all connectivity tables as transient arrays here; later, while building @@ -298,7 +360,9 @@ def _add_sdfg_params(self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym]) -> # the tables that are actually used. This way, we avoid adding SDFG arguments for # the connectivity tables that are not used. The remaining unused transient arrays # are removed by the dace simplify pass. - self._add_storage(sdfg, dace_utils.connectivity_identifier(offset), type_) + self._add_storage( + sdfg, symbolic_arguments, dace_utils.connectivity_identifier(offset), gt_type + ) # the list of all sdfg arguments (aka non-transient arrays) which include tuple-element fields return [arg_name for arg_name, _ in sdfg_args] @@ -330,7 +394,8 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: else: head_state = entry_state - sdfg_arg_names = self._add_sdfg_params(sdfg, node.params) + domain_symbols = _collect_symbols_in_domain_expressions(node, node.params) + sdfg_arg_names = self._add_sdfg_params(sdfg, node.params, symbolic_arguments=domain_symbols) # visit one statement at a time and expand the SDFG from the current head state for i, stmt in enumerate(node.body): @@ -490,7 +555,10 @@ def visit_Lambda( ] # inherit symbols from parent scope but eventually override with local symbols - lambda_symbols = self.global_symbols | { + lambda_symbols = { + sym: self.global_symbols[sym] + for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) + } | { pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.data_type for pname, arg in lambda_args_mapping } @@ -501,11 +569,12 @@ def visit_Lambda( nstate = nsdfg.add_state("lambda") # add sdfg storage for the symbols that need to be passed as input parameters + lambda_params = [ + gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items() + ] + lambda_domain_symbols = _collect_symbols_in_domain_expressions(node.expr, lambda_params) lambda_translator._add_sdfg_params( - nsdfg, - node_params=[ - gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items() - ], + nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) lambda_result = lambda_translator.visit( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 728b4b02b9..230ff695fa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -819,9 +819,7 @@ def test_gtir_cartesian_shift_left(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) - FSYMBOLS_tmp = FSYMBOLS.copy() - FSYMBOLS_tmp["__x_offset_stride_0"] = 1 - sdfg(a, a_offset, b, **FSYMBOLS_tmp) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) assert np.allclose(a[OFFSET:] + DELTA, b[:-OFFSET]) @@ -914,7 +912,7 @@ def test_gtir_cartesian_shift_right(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) assert np.allclose(a[:-OFFSET] + DELTA, b[OFFSET:]) @@ -1072,7 +1070,9 @@ def test_gtir_connectivity_shift(): __ev_field_size_1=SIMPLE_MESH.num_vertices, __ev_field_stride_0=SIMPLE_MESH.num_vertices, __ev_field_stride_1=1, + __c2e_offset_size_0=SIMPLE_MESH.num_cells, __c2e_offset_stride_0=1, + __e2v_offset_size_0=SIMPLE_MESH.num_edges, __e2v_offset_stride_0=1, ) assert np.allclose(ce, ref) @@ -1592,6 +1592,41 @@ def test_gtir_let_lambda(): assert np.allclose(b, ref) +def test_gtir_let_lambda_scalar_expression(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="let_lambda_scalar_expression", + function_definitions=[], + params=[ + gtir.Sym(id="a", type=IFTYPE.dtype), + gtir.Sym(id="b", type=IFTYPE.dtype), + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("tmp", im.multiplies_("a", "b"))( + im.op_as_fieldop("multiplies", domain)("x", im.multiplies_("tmp", "tmp")) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand() + b = np.random.rand() + c = np.random.rand(N) + d = np.empty_like(c) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + + sdfg(a, b, c, d, **FSYMBOLS) + assert np.allclose(d, (a * a * b * b * c)) + + def test_gtir_let_lambda_with_connectivity(): C2E_neighbor_idx = 1 C2V_neighbor_idx = 2 @@ -1757,9 +1792,9 @@ def test_gtir_if_scalars(): body=[ gtir.SetAt( expr=im.let("f", im.tuple_get(0, "x"))( - im.let("y", im.tuple_get(1, "x"))( - im.let("y_0", im.tuple_get(0, "y"))( - im.let("y_1", im.tuple_get(1, "y"))( + im.let("g", im.tuple_get(1, "x"))( + im.let("y_0", im.tuple_get(0, "g"))( + im.let("y_1", im.tuple_get(1, "g"))( im.op_as_fieldop("plus", domain)( "f", im.if_( From 4eb4d4d0c768b39cac4d945c936b57f18b699eaa Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 23 Oct 2024 12:24:28 +0200 Subject: [PATCH 011/178] feat[next][dace]: GTIR-to-DaCe lowering of map-reduce with skip values (#1694) This PR extends the solution for map-reduce provided in #1683 with the support for connectivity tables with skip values. The field definition is extended with a `local_offset` attribute that stores the offset provider used to build the values in the local dimension. In case the local dimension is built by the `neighbors` expression, the `local_offset` corresponds to the offset provider used to access the neighbor dimension. Since this information is carried along the data itself, whenever the data is accessed it is also possible to access the corresponding offset provider and check whether the neighbor index is valid or if there is a skip value. For local dimensions already present in the program argument, this information is retrieved from the field domain (enabled in new test case). The data is accessed in the `map_` and `reduce` expressions. Here it is now possible to check for skip values. Therefore, the main objective of this PR is the lowering of map-reduce with skip values. A secondary objective is to pave the road to simplify the lowering logic, by getting rid of the `reduce_identity` value. The current approach is propagate the `reduce_identity` value while visiting the arguments to `reduce` expressions. By introducing `local_offset`, the argument visitor will return the information needed to implement `reduce` of local values in presence of skip values. --- .../gtir_builtin_translators.py | 110 +++--- .../runners/dace_fieldview/gtir_dataflow.py | 325 +++++++++++------- .../runners/dace_fieldview/gtir_sdfg.py | 110 +++--- .../runners/dace_fieldview/utility.py | 10 +- .../dace_tests/test_gtir_to_sdfg.py | 82 +++-- 5 files changed, 404 insertions(+), 233 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 3cd2b17b88..277d8a0cd8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -34,9 +34,25 @@ @dataclasses.dataclass(frozen=True) -class Field: - data_node: dace.nodes.AccessNode - data_type: ts.FieldType | ts.ScalarType +class FieldopData: + """ + Abstraction to represent data (scalars, arrays) during the lowering to SDFG. + + Attribute 'local_offset' must always be set for `FieldType` data with a local + dimension generated from neighbors access in unstructured domain, and indicates + the name of the offset provider used to generate the list of neighbor values. + + Args: + dc_node: DaCe access node to the data storage. + gt_dtype: GT4Py type definition, which includes the field domain information. + local_offset: Provides information about the local dimension in`FieldType` data. + Set to 'None' for scalar data. Can be 'None' for `FieldType` data with + only global (horizontal or vertical) dimensions. + """ + + dc_node: dace.nodes.AccessNode + gt_dtype: ts.FieldType | ts.ScalarType + local_offset: Optional[str] FieldopDomain: TypeAlias = list[ @@ -50,7 +66,7 @@ class Field: """ -FieldopResult: TypeAlias = Field | tuple[Field | tuple, ...] +FieldopResult: TypeAlias = FieldopData | tuple[FieldopData | tuple, ...] """Result of a field operator, can be either a field or a tuple fields.""" @@ -73,7 +89,7 @@ def __call__( This method is used by derived classes to build a specialized subgraph for a specific GTIR primitive function. - Arguments: + Args: node: The GTIR node describing the primitive to be lowered sdfg: The SDFG where the primitive subgraph should be instantiated state: The SDFG state where the result of the primitive function should be made available @@ -108,24 +124,24 @@ def _parse_fieldop_arg( ) # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, Field): + if not isinstance(arg, FieldopData): raise ValueError(f"Received {node} as argument to field operator, expected a field.") - if isinstance(arg.data_type, ts.ScalarType): - return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0])) - elif isinstance(arg.data_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = { + if isinstance(arg.gt_dtype, ts.ScalarType): + return gtir_dataflow.MemletExpr(arg.dc_node, sbs.Indices([0])) + elif isinstance(arg.gt_dtype, ts.FieldType): + indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) for dim, _, _ in domain } - dims = arg.data_type.dims + ( + dims = arg.gt_dtype.dims + ( # we add an extra anonymous dimension in the iterator definition to enable # dereferencing elements in `ListType` - [gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, itir_ts.ListType) else [] + [gtx_common.Dimension("")] if isinstance(arg.gt_dtype.dtype, itir_ts.ListType) else [] ) - return gtir_dataflow.IteratorExpr(arg.data_node, dims, indices) + return gtir_dataflow.IteratorExpr(arg.dc_node, dims, indices, arg.local_offset) else: - raise NotImplementedError(f"Node type {type(arg.data_type)} not supported.") + raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.") def _create_temporary_field( @@ -134,7 +150,7 @@ def _create_temporary_field( domain: FieldopDomain, node_type: ts.FieldType, dataflow_output: gtir_dataflow.DataflowOutputEdge, -) -> Field: +) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" domain_dims, _, domain_ubs = zip(*domain) field_dims = list(domain_dims) @@ -148,7 +164,7 @@ def _create_temporary_field( # eliminate most of transient arrays. field_shape = list(domain_ubs) - output_desc = dataflow_output.result.node.desc(sdfg) + output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) @@ -165,7 +181,7 @@ def _create_temporary_field( field_node = state.add_access(temp_name) field_type = ts.FieldType(field_dims, node_type.dtype) - return Field(field_node, field_type) + return FieldopData(field_node, field_type, local_offset=dataflow_output.result.local_offset) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -284,7 +300,7 @@ def translate_as_fieldop( # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.node.desc(sdfg) + output_desc = output.result.dc_node.desc(sdfg) if isinstance(node.type.dtype, itir_ts.ListType): assert isinstance(output_desc, dace.data.Array) @@ -314,7 +330,7 @@ def translate_as_fieldop( edge.connect(me) # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.data_node, output_subset) + output.connect(mx, result_field.dc_node, output_subset) return result_field @@ -353,7 +369,7 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.MemletExpr) assert scalar_expr.subset == sbs.Indices.from_string("0") result = gtir_dataflow.DataflowOutputEdge( - state, gtir_dataflow.DataExpr(scalar_expr.node, node.args[0].type) + state, gtir_dataflow.ValueExpr(scalar_expr.dc_node, node.args[0].type) ) result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result) @@ -364,11 +380,11 @@ def translate_broadcast_scalar( dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" for dim, lower_bound, upper_bound in domain }, - inputs={"__inp": dace.Memlet(data=scalar_expr.node.data, subset="0")}, + inputs={"__inp": dace.Memlet(data=scalar_expr.dc_node.data, subset="0")}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=result_field.data_node.data, subset=domain_indices)}, - input_nodes={scalar_expr.node.data: scalar_expr.node}, - output_nodes={result_field.data_node.data: result_field.data_node}, + outputs={"__val": dace.Memlet(data=result_field.dc_node.data, subset=domain_indices)}, + input_nodes={scalar_expr.dc_node.data: scalar_expr.dc_node}, + output_nodes={result_field.dc_node.data: result_field.dc_node}, external_edges=True, ) @@ -431,16 +447,16 @@ def translate_if( reduce_identity=reduce_identity, ) - def make_temps(x: Field) -> Field: - desc = x.data_node.desc(sdfg) + def make_temps(output_data: FieldopData) -> FieldopData: + desc = output_data.dc_node.desc(sdfg) data_name, _ = sdfg.add_temp_transient_like(desc) data_node = state.add_access(data_name) - return Field(data_node, x.data_type) + return FieldopData(data_node, output_data.gt_dtype, output_data.local_offset) result_temps = gtx_utils.tree_map(make_temps)(true_br_args) - fields: Iterable[tuple[Field, Field, Field]] = zip( + fields: Iterable[tuple[FieldopData, FieldopData, FieldopData]] = zip( gtx_utils.flatten_nested_tuple((true_br_args,)), gtx_utils.flatten_nested_tuple((false_br_args,)), gtx_utils.flatten_nested_tuple((result_temps,)), @@ -448,11 +464,11 @@ def make_temps(x: Field) -> Field: ) for true_br, false_br, temp in fields: - assert true_br.data_type == false_br.data_type - true_br_node = true_br.data_node - false_br_node = false_br.data_node + assert true_br.gt_dtype == false_br.gt_dtype + true_br_node = true_br.dc_node + false_br_node = false_br.dc_node - temp_name = temp.data_node.data + temp_name = temp.dc_node.data true_br_output_node = true_state.add_access(temp_name) true_state.add_nedge( true_br_node, @@ -479,7 +495,19 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(sym_type, ts.FieldType): sym_node = state.add_access(sym_name) - return Field(sym_node, sym_type) + local_dims = [dim for dim in sym_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] + if len(local_dims) > 1: + raise ValueError(f"Field {sym_name} has more than one local dimension.") + elif len(local_dims) == 1: + # we ensure that the name of the local dimension corresponds to a valid + # connectivity-based offset provider + local_offset = next(iter(local_dims)).value + assert isinstance( + sdfg_builder.get_offset_provider(local_offset), gtx_common.Connectivity + ) + else: + local_offset = None + return FieldopData(sym_node, sym_type, local_offset) elif isinstance(sym_type, ts.ScalarType): if sym_name in sdfg.symbols: sym_node = _get_symbolic_value( @@ -487,7 +515,7 @@ def _get_data_nodes( ) else: sym_node = state.add_access(sym_name) - return Field(sym_node, sym_type) + return FieldopData(sym_node, sym_type, local_offset=None) elif isinstance(sym_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(sym_name, sym_type) return tuple( @@ -543,7 +571,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return Field(data_node, data_type) + return FieldopData(data_node, data_type, local_offset=None) def translate_make_tuple( @@ -586,13 +614,13 @@ def translate_tuple_get( head_state=state, reduce_identity=reduce_identity, ) - if isinstance(data_nodes, Field): + if isinstance(data_nodes, FieldopData): raise ValueError(f"Invalid tuple expression {node}") - unused_arg_nodes: Iterable[Field] = gtx_utils.flatten_nested_tuple( + unused_arg_nodes: Iterable[FieldopData] = gtx_utils.flatten_nested_tuple( tuple(arg for i, arg in enumerate(data_nodes) if i != index) ) state.remove_nodes_from( - [arg.data_node for arg in unused_arg_nodes if state.degree(arg.data_node) == 0] + [arg.dc_node for arg in unused_arg_nodes if state.degree(arg.dc_node) == 0] ) return data_nodes[index] @@ -633,10 +661,10 @@ def translate_scalar_expr( head_state=state, reduce_identity=reduce_identity, ) - if not (isinstance(arg, Field) and isinstance(arg.data_type, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(arg.gt_dtype, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" - args.append(arg.data_node) + args.append(arg.dc_node) connectors.append(param) scalar_expr_args.append(gtir.SymRef(id=param)) else: @@ -678,7 +706,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return Field(temp_node, node.type) + return FieldopData(temp_node, node.type, local_offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index fdd92b57c1..4f6a1e04c6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -30,19 +30,43 @@ @dataclasses.dataclass(frozen=True) -class DataExpr: - """Local storage for the computation result returned by a tasklet node.""" +class ValueExpr: + """ + Local storage for the values returned by dataflow computation. + + This type is used in the context in a dataflow, that is a stencil expression. + Therefore, it contains either a scalar value (single elements in the fields) or + a list of values in a local dimension. + This is different from `gtir_builtin_translators.FieldopData` which represents + the result of a field operator, basically the data storage outside a global map. - node: dace.nodes.AccessNode - dtype: itir_ts.ListType | ts.ScalarType + Args: + dc_node: Access node to the data storage, can be either a scalar or a local list. + gt_dtype: GT4Py type definition, which includes the field domain information. + local_offset: Provides information about the local dimension in`FieldType` data. + For a more detailed explanation see `gtir_builtin_translators.FieldopData`. + """ + + dc_node: dace.nodes.AccessNode + gt_dtype: itir_ts.ListType | ts.ScalarType + local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) class MemletExpr: - """Scalar or array data access through a memlet.""" + """ + Scalar or array data access through a memlet. - node: dace.nodes.AccessNode + Args: + dc_node: Access node to the data storage, can be either a scalar or a local list. + subset: Represents the subset to use in memlet to access the above data. + local_offset: Provides information about the local dimension in`FieldType` data. + For a more detailed explanation see `gtir_builtin_translators.FieldopData`. + """ + + dc_node: dace.nodes.AccessNode subset: sbs.Indices | sbs.Range + local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) @@ -50,10 +74,10 @@ class SymbolExpr: """Any symbolic expression that is constant in the context of current SDFG.""" value: dace.symbolic.SymExpr - dtype: dace.typeclass + dc_dtype: dace.typeclass -ValueExpr: TypeAlias = DataExpr | MemletExpr | SymbolExpr +DataExpr: TypeAlias = ValueExpr | MemletExpr | SymbolExpr @dataclasses.dataclass(frozen=True) @@ -62,18 +86,20 @@ class IteratorExpr: Iterator for field access to be consumed by `deref` or `shift` builtin functions. Args: - field: The field this iterator operates on. - dimensions: Field domain represented as a sorted list of dimensions. - In order to dereference an element in the field, we need index values - for all the dimensions in the right order. + field: Access node to the field this iterator operates on. + dimensions: Field domain represented as a sorted list of dimensions, needed + to order the map index variables and dereference an element in the field. indices: Maps each dimension to an index value, which could be either a symbolic value - or the result of a tasklet computation like neighbors connectivity or dynamic offset. + or the result of a tasklet computation like neighbors connectivity or dynamic offset. + local_offset: Provides information about the local dimension in`FieldType` data. + For a more detailed explanation see `gtir_builtin_translators.FieldopData`. """ field: dace.nodes.AccessNode dimensions: list[gtx_common.Dimension] - indices: dict[gtx_common.Dimension, ValueExpr] + indices: dict[gtx_common.Dimension, DataExpr] + local_offset: Optional[str] = None class DataflowInputEdge(Protocol): @@ -119,7 +145,7 @@ def connect(self, me: dace.nodes.MapEntry) -> None: @dataclasses.dataclass(frozen=True) class EmptyInputEdge(DataflowInputEdge): """ - Allows to setup an edge from a map entry node to a tasklet with no arguements. + Allows to setup an edge from a map entry node to a tasklet with no arguments. The reason behind this kind of connection is that all nodes inside a map scope must have an in/out path that traverses the entry and exit nodes. @@ -145,30 +171,30 @@ class DataflowOutputEdge: """ state: dace.SDFGState - result: DataExpr + result: ValueExpr def connect( self, mx: dace.nodes.MapExit, - result_node: dace.nodes.AccessNode, + dest: dace.nodes.AccessNode, subset: sbs.Range, ) -> None: # retrieve the node which writes the result - last_node = self.state.in_edges(self.result.node)[0].src + last_node = self.state.in_edges(self.result.dc_node)[0].src if isinstance(last_node, dace.nodes.Tasklet): # the last transient node can be deleted - last_node_connector = self.state.in_edges(self.result.node)[0].src_conn - self.state.remove_node(self.result.node) + last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn + self.state.remove_node(self.result.dc_node) else: - last_node = self.result.node + last_node = self.result.dc_node last_node_connector = None self.state.add_memlet_path( last_node, mx, - result_node, + dest, src_conn=last_node_connector, - memlet=dace.Memlet(data=result_node.data, subset=subset), + memlet=dace.Memlet(data=dest.data, subset=subset), ) @@ -187,7 +213,7 @@ def connect( def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) + dc_dtype = dace_utils.as_dace_type(node.type) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 2 @@ -195,12 +221,12 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: op_name = str(node.fun.args[0]) assert isinstance(node.fun.args[1], gtir.Literal) assert node.fun.args[1].type == node.type - reduce_init = SymbolExpr(node.fun.args[1].value, dtype) + reduce_init = SymbolExpr(node.fun.args[1].value, dc_dtype) if op_name not in DACE_REDUCTION_MAPPING: raise RuntimeError(f"Reduction operation '{op_name}' not supported.") - identity_value = dace.dtypes.reduction_identity(dtype, DACE_REDUCTION_MAPPING[op_name]) - reduce_identity = SymbolExpr(identity_value, dtype) + identity_value = dace.dtypes.reduction_identity(dc_dtype, DACE_REDUCTION_MAPPING[op_name]) + reduce_identity = SymbolExpr(identity_value, dc_dtype) return op_name, reduce_init, reduce_identity @@ -325,9 +351,9 @@ def _add_mapped_tasklet( name, self.state, map_ranges, inputs, code, outputs, **kwargs ) - def _construct_local_view(self, field: MemletExpr | DataExpr) -> DataExpr: + def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr: if isinstance(field, MemletExpr): - desc = field.node.desc(self.sdfg) + desc = field.dc_node.desc(self.sdfg) local_dim_indices = [i for i, size in enumerate(field.subset.size()) if size != 1] if len(local_dim_indices) == 0: # we are accessing a single-element array with shape (1,) @@ -337,36 +363,37 @@ def _construct_local_view(self, field: MemletExpr | DataExpr) -> DataExpr: view_shape = tuple(desc.shape[i] for i in local_dim_indices) view_strides = tuple(desc.strides[i] for i in local_dim_indices) view, _ = self.sdfg.add_view( - f"{field.node.data}_view", + f"{field.dc_node.data}_view", view_shape, desc.dtype, strides=view_strides, find_new_name=True, ) local_view_node = self.state.add_access(view) - self._add_input_data_edge(field.node, field.subset, local_view_node) + self._add_input_data_edge(field.dc_node, field.subset, local_view_node) - return DataExpr(local_view_node, desc.dtype) + return ValueExpr(local_view_node, desc.dtype) else: return field def _construct_tasklet_result( self, - dtype: dace.typeclass, + dc_dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str, + local_offset: Optional[str] = None, use_array: bool = False, - ) -> DataExpr: + ) -> ValueExpr: temp_name = self.sdfg.temp_data_name() if use_array: # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) # in order to allow for composition of array shape in external memlets. - self.sdfg.add_array(temp_name, (1,), dtype, transient=True) + self.sdfg.add_array(temp_name, (1,), dc_dtype, transient=True) else: - self.sdfg.add_scalar(temp_name, dtype, transient=True) - data_type = dace_utils.as_itir_type(dtype) + self.sdfg.add_scalar(temp_name, dc_dtype, transient=True) + data_type = dace_utils.as_itir_type(dc_dtype) temp_node = self.state.add_access(temp_name) self._add_edge( src_node, @@ -375,9 +402,9 @@ def _construct_tasklet_result( None, dace.Memlet(data=temp_name, subset="0"), ) - return DataExpr(temp_node, data_type) + return ValueExpr(temp_node, data_type, local_offset) - def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: + def _visit_deref(self, node: gtir.FunCall) -> DataExpr: """ Visit a `deref` node, which represents dereferencing of an iterator. The iterator is the argument of this node. @@ -409,7 +436,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: else (0, size - 1, 1) for dim, size in zip(arg_expr.dimensions, field_desc.shape) ) - return MemletExpr(arg_expr.field, field_subset) + return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset) else: # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, @@ -449,31 +476,33 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: deref_connector = IndexConnectorFmt.format(dim=dim.value) if isinstance(index_expr, MemletExpr): self._add_input_data_edge( - index_expr.node, + index_expr.dc_node, index_expr.subset, deref_node, deref_connector, ) - elif isinstance(index_expr, DataExpr): + elif isinstance(index_expr, ValueExpr): self._add_edge( - index_expr.node, + index_expr.dc_node, None, deref_node, deref_connector, - dace.Memlet(data=index_expr.node.data, subset="0"), + dace.Memlet(data=index_expr.dc_node.data, subset="0"), ) else: assert isinstance(index_expr, SymbolExpr) - dtype = arg_expr.field.desc(self.sdfg).dtype - return self._construct_tasklet_result(dtype, deref_node, "val") + dc_dtype = arg_expr.field.desc(self.sdfg).dtype + return self._construct_tasklet_result( + dc_dtype, deref_node, "val", arg_expr.local_offset + ) else: # dereferencing a scalar or a literal node results in the node itself return arg_expr - def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert len(node.args) == 2 assert isinstance(node.type, itir_ts.ListType) @@ -532,24 +561,26 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: ) neighbors_node = self.state.add_access(neighbors_temp) - offset_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) - neighbor_idx = dace_gtir_utils.get_map_variable(offset_dim) + neighbor_idx = dace_gtir_utils.get_map_variable(offset) index_connector = "__index" output_connector = "__val" tasklet_expression = f"{output_connector} = __field[{index_connector}]" input_memlets = { - "__field": self.sdfg.make_array_memlet(field_slice.node.data), - index_connector: dace.Memlet(data=connectivity_slice.node.data, subset=neighbor_idx), + "__field": self.sdfg.make_array_memlet(field_slice.dc_node.data), + index_connector: dace.Memlet(data=connectivity_slice.dc_node.data, subset=neighbor_idx), } input_nodes = { - field_slice.node.data: field_slice.node, - connectivity_slice.node.data: connectivity_slice.node, + field_slice.dc_node.data: field_slice.dc_node, + connectivity_slice.dc_node.data: connectivity_slice.dc_node, } if offset_provider.has_skip_values: - assert self.reduce_identity is not None - assert self.reduce_identity.dtype == field_desc.dtype + if self.reduce_identity is None: + raise ValueError( + f"Found local offset '{offset}' with skip values, but 'reduce_identity' is not set." + ) + assert self.reduce_identity.dc_dtype == field_desc.dtype tasklet_expression += f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {field_desc.dtype}({self.reduce_identity.value})" self._add_mapped_tasklet( @@ -565,9 +596,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: external_edges=True, ) - return DataExpr(neighbors_node, node.type) + return ValueExpr(neighbors_node, node.type, offset) - def _visit_map(self, node: gtir.FunCall) -> DataExpr: + def _visit_map(self, node: gtir.FunCall) -> ValueExpr: """ A map node defines an operation to be mapped on all elements of input arguments. @@ -588,7 +619,7 @@ def _visit_map(self, node: gtir.FunCall) -> DataExpr: assert len(node.fun.args) == 1 # the operation to be mapped on the arguments assert isinstance(node.type.element_type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type.element_type) + dc_dtype = dace_utils.as_dace_type(node.type.element_type) input_args = [self.visit(arg) for arg in node.args] input_connectors = [f"__arg{i}" for i in range(len(input_args))] @@ -599,9 +630,27 @@ def _visit_map(self, node: gtir.FunCall) -> DataExpr: fun_python_code = gtir_python_codegen.get_source(fun_node) tasklet_expression = f"{output_connector} = {fun_python_code}" - # TODO(edopao): extract offset_dim from the input arguments - offset_dim = gtx_common.Dimension("", gtx_common.DimensionKind.LOCAL) - map_index = dace_gtir_utils.get_map_variable(offset_dim) + input_local_offsets = [ + input_arg.local_offset for input_arg in input_args if input_arg.local_offset is not None + ] + if len(input_local_offsets) == 0: + raise ValueError(f"Missing information on local dimension for map node {node}.") + + # GT4Py guarantees that all connectivities used to generate lists of neighbors + # have the same length, that is the same value of 'max_neighbors'. + local_connectivities = dace_utils.filter_connectivities( + { + offset: self.subgraph_builder.get_offset_provider(offset) + for offset in input_local_offsets + } + ) + if len(set(table.max_neighbors for table in local_connectivities.values())) != 1: + raise ValueError( + "Unexpected arguments to map expression with different local dimensions." + ) + local_offset, offset_provider = next(iter(local_connectivities.items())) + local_size = offset_provider.max_neighbors + map_index = dace_gtir_utils.get_map_variable(local_offset) # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to @@ -611,9 +660,9 @@ def _visit_map(self, node: gtir.FunCall) -> DataExpr: # than representing map-to-map edges (which require memlets with 2 pass-nodes). input_memlets = {} input_nodes = {} - local_size: Optional[int] = None + skip_value_connectivities: dict[str, gtx_common.Connectivity] = {} for conn, input_expr in zip(input_connectors, input_args): - input_node = self._construct_local_view(input_expr).node + input_node = self._construct_local_view(input_expr).dc_node input_desc = input_node.desc(self.sdfg) # we assume that there is a single local dimension if len(input_desc.shape) != 1: @@ -621,21 +670,59 @@ def _visit_map(self, node: gtir.FunCall) -> DataExpr: input_size = input_desc.shape[0] if input_size == 1: input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0") - elif local_size is not None and input_size != local_size: - raise ValueError(f"Invalid node {node}") + elif input_size != local_size: + raise ValueError( + f"Argument to map node with local size {input_size}, expected {local_size}." + ) else: + assert input_expr.local_offset input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) - local_size = input_size input_nodes[input_node.data] = input_node - if local_size is None: - # corner case where map is applied to 1-element lists - assert len(input_nodes) >= 1 - local_size = 1 + result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) + result_node = self.state.add_access(result) + + skip_value_connectivities = { + offset: offset_provider + for offset, offset_provider in local_connectivities.items() + if offset_provider.has_skip_values + } + + if len(skip_value_connectivities) == 0: + result_offset = local_offset + else: + # In case one or more of input expressions contain skip values, we use + # the connectivity-based offset provider as mask for map computation. + # Therefore, the result of map computation will also contain skip values. + # GT4Py guarantees that the skip values are placed in the same positions + # for all input expressions. + + result_offset, offset_provider = next(iter(skip_value_connectivities.items())) + + connectivity = dace_utils.connectivity_identifier(result_offset) + connectivity_desc = self.sdfg.arrays[connectivity] + connectivity_desc.transient = False - out, _ = self.sdfg.add_temp_transient((local_size,), dtype) - out_node = self.state.add_access(out) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + + connectivity_slice = self._construct_local_view( + MemletExpr( + self.state.add_access(connectivity), + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + ) + ) + + if self.reduce_identity is None: + raise ValueError( + f"Found local offset '{result_offset}' with skip values, but 'reduce_identity' is not set." + ) + assert self.reduce_identity.dc_dtype == dc_dtype + input_memlets["__neighbor_idx"] = dace.Memlet( + data=connectivity_slice.dc_node.data, subset=map_index + ) + input_nodes[connectivity_slice.dc_node.data] = connectivity_slice.dc_node + tasklet_expression += f" if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {dc_dtype}({self.reduce_identity.value})" self._add_mapped_tasklet( name="map", @@ -644,15 +731,15 @@ def _visit_map(self, node: gtir.FunCall) -> DataExpr: inputs=input_memlets, input_nodes=input_nodes, outputs={ - output_connector: dace.Memlet(data=out, subset=map_index), + output_connector: dace.Memlet(data=result, subset=map_index), }, - output_nodes={out: out_node}, + output_nodes={result: result_node}, external_edges=True, ) - return DataExpr(out_node, dtype) + return ValueExpr(result_node, dc_dtype, result_offset) - def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: + def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, ts.ScalarType) op_name, reduce_init, reduce_identity = get_reduce_params(node) @@ -672,8 +759,8 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: # ensure that we leave the visitor in the same state as we entered self.reduce_identity = prev_reduce_identity - assert isinstance(input_expr, MemletExpr | DataExpr) - input_desc = input_expr.node.desc(self.sdfg) + assert isinstance(input_expr, MemletExpr | ValueExpr) + input_desc = input_expr.dc_node.desc(self.sdfg) assert isinstance(input_desc, dace.data.Array) if len(input_desc.shape) > 1: @@ -691,19 +778,19 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: if isinstance(input_expr, MemletExpr): self._add_input_data_edge( - input_expr.node, + input_expr.dc_node, input_expr.subset, reduce_node, ) else: self.state.add_nedge( - input_expr.node, + input_expr.dc_node, reduce_node, - dace.Memlet.from_array(input_expr.node.data, input_desc), + dace.Memlet.from_array(input_expr.dc_node.data, input_desc), ) temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, reduce_identity.dtype, transient=True) + self.sdfg.add_scalar(temp_name, reduce_identity.dc_dtype, transient=True) temp_node = self.state.add_access(temp_name) self.state.add_nedge( @@ -711,7 +798,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: temp_node, dace.Memlet(data=temp_name, subset="0"), ) - return DataExpr(temp_node, node.type) + return ValueExpr(temp_node, node.type) def _split_shift_args( self, args: list[gtir.Expr] @@ -742,18 +829,18 @@ def _visit_shift_multidim( return offset_provider_arg, offset_value_arg, it def _make_cartesian_shift( - self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: ValueExpr + self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" assert offset_dim in it.dimensions - new_index: SymbolExpr | DataExpr + new_index: SymbolExpr | ValueExpr assert offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, - index_expr.dtype, + index_expr.dc_dtype, ) else: # the offset needs to be calculated by means of a tasklet (i.e. dynamic offset) @@ -782,27 +869,27 @@ def _make_cartesian_shift( for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): self._add_input_data_edge( - input_expr.node, + input_expr.dc_node, input_expr.subset, dynamic_offset_tasklet, input_connector, ) - elif isinstance(input_expr, DataExpr): + elif isinstance(input_expr, ValueExpr): self._add_edge( - input_expr.node, + input_expr.dc_node, None, dynamic_offset_tasklet, input_connector, - dace.Memlet(data=input_expr.node.data, subset="0"), + dace.Memlet(data=input_expr.dc_node.data, subset="0"), ) if isinstance(index_expr, SymbolExpr): - dtype = index_expr.dtype + dc_dtype = index_expr.dc_dtype else: - dtype = index_expr.node.desc(self.sdfg).dtype + dc_dtype = index_expr.dc_node.desc(self.sdfg).dtype new_index = self._construct_tasklet_result( - dtype, dynamic_offset_tasklet, new_index_connector + dc_dtype, dynamic_offset_tasklet, new_index_connector ) # a new iterator with a shifted index along one dimension @@ -814,10 +901,10 @@ def _make_cartesian_shift( def _make_dynamic_neighbor_offset( self, - offset_expr: MemletExpr | DataExpr, + offset_expr: MemletExpr | ValueExpr, offset_table_node: dace.nodes.AccessNode, origin_index: SymbolExpr, - ) -> DataExpr: + ) -> ValueExpr: """ Implements access to neighbor connectivity table by means of a tasklet node. @@ -839,29 +926,29 @@ def _make_dynamic_neighbor_offset( ) if isinstance(offset_expr, MemletExpr): self._add_input_data_edge( - offset_expr.node, + offset_expr.dc_node, offset_expr.subset, tasklet_node, "offset", ) else: self._add_edge( - offset_expr.node, + offset_expr.dc_node, None, tasklet_node, "offset", - dace.Memlet(data=offset_expr.node.data, subset="0"), + dace.Memlet(data=offset_expr.dc_node.data, subset="0"), ) - dtype = offset_table_node.desc(self.sdfg).dtype - return self._construct_tasklet_result(dtype, tasklet_node, new_index_connector) + dc_dtype = offset_table_node.desc(self.sdfg).dtype + return self._construct_tasklet_result(dc_dtype, tasklet_node, new_index_connector) def _make_unstructured_shift( self, it: IteratorExpr, connectivity: gtx_common.Connectivity, offset_table_node: dace.nodes.AccessNode, - offset_expr: ValueExpr, + offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" assert connectivity.neighbor_axis in it.dimensions @@ -928,16 +1015,16 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: it, offset_provider, offset_table_node, offset_expr ) - def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: + def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: """ Generic handler called by `visit_FunCall()` when it encounters a builtin function that does not match any other specific handler. """ node_internals = [] - node_connections: dict[str, MemletExpr | DataExpr] = {} + node_connections: dict[str, MemletExpr | ValueExpr] = {} for i, arg in enumerate(node.args): arg_expr = self.visit(arg) - if isinstance(arg_expr, MemletExpr | DataExpr): + if isinstance(arg_expr, MemletExpr | ValueExpr): # the argument value is the result of a tasklet node or direct field access connector = f"__arg{i}" node_connections[connector] = arg_expr @@ -961,17 +1048,17 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: ) for connector, arg_expr in node_connections.items(): - if isinstance(arg_expr, DataExpr): + if isinstance(arg_expr, ValueExpr): self._add_edge( - arg_expr.node, + arg_expr.dc_node, None, tasklet_node, connector, - dace.Memlet(data=arg_expr.node.data, subset="0"), + dace.Memlet(data=arg_expr.dc_node.data, subset="0"), ) else: self._add_input_data_edge( - arg_expr.node, + arg_expr.dc_node, arg_expr.subset, tasklet_node, connector, @@ -988,18 +1075,18 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: # Therefore we handle `ListType` as a single-element array with shape (1,) # that will be accessed in a map expression on a local domain. assert isinstance(node.type.element_type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type.element_type) + dc_dtype = dace_utils.as_dace_type(node.type.element_type) # In order to ease the lowring of the parent expression on local dimension, # we represent the scalar value as a single-element 1D array. use_array = True else: assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) + dc_dtype = dace_utils.as_dace_type(node.type) use_array = False - return self._construct_tasklet_result(dtype, tasklet_node, "result", use_array=use_array) + return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: + def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -1026,16 +1113,16 @@ def visit_Lambda( ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr: ValueExpr = self.visit(node.expr) - if isinstance(output_expr, DataExpr): + output_expr: DataExpr = self.visit(node.expr) + if isinstance(output_expr, ValueExpr): return self.input_edges, DataflowOutputEdge(self.state, output_expr) if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.node.desc(self.sdfg).dtype + output_dtype = output_expr.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.node, + output_expr.dc_node, output_expr.subset, tasklet_node, "__inp", @@ -1043,15 +1130,15 @@ def visit_Lambda( else: assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dtype + output_dtype = output_expr.dc_dtype tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") return self.input_edges, DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: - dtype = dace_utils.as_dace_type(node.type) - return SymbolExpr(node.value, dtype) + dc_dtype = dace_utils.as_dace_type(node.type) + return SymbolExpr(node.value, dc_dtype) def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: param = str(node.id) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 31e561a19c..4627293cfd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -186,18 +186,18 @@ def _make_array_shape_and_strides( Returns: Two lists of symbols, one for the shape and the other for the strides of the array. """ - dtype = dace.int32 + dc_dtype = gtir_builtin_translators.INDEX_DTYPE neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) shape = [ ( neighbor_tables[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) + else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) ) for i, dim in enumerate(dims) ] strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) + dace.symbol(dace_utils.field_stride_symbol_name(name, i), dc_dtype) for i in range(len(dims)) ] return shape, strides @@ -247,18 +247,18 @@ def _add_storage( return tuple_fields elif isinstance(gt_type, ts.FieldType): - dtype = dace_utils.as_dace_type(gt_type.dtype) + dc_dtype = dace_utils.as_dace_type(gt_type.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) - sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=transient) + sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): - dtype = dace_utils.as_dace_type(gt_type) + dc_dtype = dace_utils.as_dace_type(gt_type) if name in symbolic_arguments: - sdfg.add_symbol(name, dtype) + sdfg.add_symbol(name, dc_dtype) elif dace_utils.is_field_symbol(name): # Sometimes, when the field domain is implicitly derived from the # field domain, the gt4py lowering adds the field size as a scalar @@ -269,11 +269,11 @@ def _add_storage( # storage for field arguments. We assume that the scalar argument # for field size, if present, always follows the field argument. if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype + assert sdfg.symbols[name].dc_dtype == dc_dtype else: - sdfg.add_symbol(name, dtype) + sdfg.add_symbol(name, dc_dtype) else: - sdfg.add_scalar(name, dtype, transient=transient) + sdfg.add_scalar(name, dc_dtype, transient=transient) return [(name, gt_type)] @@ -289,7 +289,7 @@ def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str def _visit_expression( self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState, use_temp: bool = True - ) -> list[gtir_builtin_translators.Field]: + ) -> list[gtir_builtin_translators.FieldopData]: """ Specialized visit method for fieldview expressions. @@ -307,17 +307,21 @@ def _visit_expression( assert len(sink_states) == 1 assert sink_states[0] == head_state - def make_temps(field: gtir_builtin_translators.Field) -> gtir_builtin_translators.Field: - desc = sdfg.arrays[field.data_node.data] + def make_temps( + field: gtir_builtin_translators.FieldopData, + ) -> gtir_builtin_translators.FieldopData: + desc = sdfg.arrays[field.dc_node.data] if desc.transient or not use_temp: return field else: temp, _ = sdfg.add_temp_transient_like(desc) temp_node = head_state.add_access(temp) head_state.add_nedge( - field.data_node, temp_node, sdfg.make_array_memlet(field.data_node.data) + field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) + ) + return gtir_builtin_translators.FieldopData( + temp_node, field.gt_dtype, field.local_offset ) - return gtir_builtin_translators.Field(temp_node, field.data_type) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -449,35 +453,35 @@ def visit_SetAt( target_state: Optional[dace.SDFGState] = None for temp, target in zip(temp_fields, target_fields, strict=True): - target_desc = sdfg.arrays[target.data_node.data] + target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient - if isinstance(target.data_type, ts.FieldType): + if isinstance(target.gt_dtype, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.data_type.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_dtype.dims ) else: assert len(domain) == 0 subset = "0" - if target.data_node.data in state_input_data: + if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X if not target_state: target_state = sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.data_node.data), - target_state.add_access(target.data_node.data), - dace.Memlet(data=target.data_node.data, subset=subset, other_subset=subset), + target_state.add_access(temp.dc_node.data), + target_state.add_access(target.dc_node.data), + dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), ) # remove isolated access node - state.remove_node(target.data_node) + state.remove_node(target.dc_node) else: state.add_nedge( - temp.data_node, - target.data_node, - dace.Memlet(data=target.data_node.data, subset=subset, other_subset=subset), + temp.dc_node, + target.dc_node, + dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), ) return target_state or state @@ -559,7 +563,7 @@ def visit_Lambda( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.data_type + pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_dtype for pname, arg in lambda_args_mapping } @@ -587,7 +591,7 @@ def visit_Lambda( def _flatten_tuples( name: str, arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.Field]]: + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: if isinstance(arg, tuple): tuple_type = dace_gtir_utils.get_tuple_type(arg) tuple_field_names = [ @@ -617,7 +621,7 @@ def _flatten_tuples( continue datadesc: Optional[dace.dtypes.Data] = None if nsdfg_dataname in lambda_arg_nodes: - src_node = lambda_arg_nodes[nsdfg_dataname].data_node + src_node = lambda_arg_nodes[nsdfg_dataname].dc_node dataname = src_node.data datadesc = src_node.desc(sdfg) else: @@ -642,18 +646,20 @@ def _flatten_tuples( # Process lambda outputs # - lambda_output_nodes: Iterable[gtir_builtin_translators.Field] = ( + lambda_output_data: Iterable[gtir_builtin_translators.FieldopData] = ( gtx_utils.flatten_nested_tuple(lambda_result) ) # sanity check on isolated nodes assert all( - nstate.degree(x.data_node) == 0 - for x in lambda_output_nodes - if x.data_node.data in input_memlets + nstate.degree(output_data.dc_node) == 0 + for output_data in lambda_output_data + if output_data.dc_node.data in input_memlets ) # keep only non-isolated output nodes lambda_outputs = { - x.data_node.data for x in lambda_output_nodes if x.data_node.data not in input_memlets + output_data.dc_node.data + for output_data in lambda_output_data + if output_data.dc_node.data not in input_memlets } if lambda_outputs: @@ -668,36 +674,40 @@ def _flatten_tuples( for connector, memlet in input_memlets.items(): if connector in lambda_arg_nodes: - src_node = lambda_arg_nodes[connector].data_node + src_node = lambda_arg_nodes[connector].dc_node else: src_node = head_state.add_access(memlet.data) head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) def make_temps( - x: gtir_builtin_translators.Field, - ) -> gtir_builtin_translators.Field: - if x.data_node.data in lambda_outputs: - connector = x.data_node.data - desc = x.data_node.desc(nsdfg) + output_data: gtir_builtin_translators.FieldopData, + ) -> gtir_builtin_translators.FieldopData: + if output_data.dc_node.data in lambda_outputs: + connector = output_data.dc_node.data + desc = output_data.dc_node.desc(nsdfg) # make lambda result non-transient and map it to external temporary desc.transient = False # isolated access node will make validation fail - if nstate.degree(x.data_node) == 0: - nstate.remove_node(x.data_node) + if nstate.degree(output_data.dc_node) == 0: + nstate.remove_node(output_data.dc_node) temp, _ = sdfg.add_temp_transient_like(desc) dst_node = head_state.add_access(temp) head_state.add_edge( nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) ) - return gtir_builtin_translators.Field(dst_node, x.data_type) - elif x.data_node.data in lambda_arg_nodes: - nstate.remove_node(x.data_node) - return lambda_arg_nodes[x.data_node.data] + return gtir_builtin_translators.FieldopData( + dst_node, output_data.gt_dtype, output_data.local_offset + ) + elif output_data.dc_node.data in lambda_arg_nodes: + nstate.remove_node(output_data.dc_node) + return lambda_arg_nodes[output_data.dc_node.data] else: - nstate.remove_node(x.data_node) - data_node = head_state.add_access(x.data_node.data) - return gtir_builtin_translators.Field(data_node, x.data_type) + nstate.remove_node(output_data.dc_node) + data_node = head_state.add_access(output_data.dc_node.data) + return gtir_builtin_translators.FieldopData( + data_node, output_data.gt_dtype, output_data.local_offset + ) return gtx_utils.tree_map(make_temps)(lambda_result) @@ -734,7 +744,7 @@ def build_sdfg_from_gtir( The lowering to SDFG requires that the program node is type-annotated, therefore this function runs type ineference as first step. - Arguments: + Args: ir: The GTIR program node to be lowered to SDFG offset_provider: The definitions of offset providers used by the program node diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index b5c447a1be..355eaac903 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -15,12 +15,16 @@ from gt4py.next.type_system import type_specifications as ts -def get_map_variable(dim: gtx_common.Dimension) -> str: +def get_map_variable(dim: gtx_common.Dimension | str) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. """ + if not isinstance(dim, gtx_common.Dimension): + if len(dim) != 0: + dim = gtx_common.Dimension(dim, gtx_common.DimensionKind.LOCAL) + else: + raise ValueError("Dimension name cannot be empty.") suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" - # TODO(edopao): raise exception if dim.value is empty return f"i_{dim.value}_gtx_{dim.kind}{suffix}" @@ -60,5 +64,5 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. """ return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.data_type for d in data] + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data] ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 230ff695fa..2c1c04ce1f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1137,8 +1137,6 @@ def test_gtir_connectivity_shift_chain(): def test_gtir_neighbors_as_input(): - # FIXME[#1582](edopao): Enable testcase when type inference is working - pytest.skip("Field of lists not fully supported by GTIR type inference") init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) testee = gtir.Program( @@ -1146,24 +1144,28 @@ def test_gtir_neighbors_as_input(): function_definitions=[], params=[ gtir.Sym(id="v2e_field", type=V2E_FTYPE), - gtir.Sym(id="vertex", type=EFTYPE), + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), ], declarations=[], body=[ gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - "it" - ) - ), - vertex_domain, + expr=im.as_fieldop( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) + ), + vertex_domain, + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + "v2e_field", + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) - )("v2e_field"), + ), domain=vertex_domain, - target=gtir.SymRef(id="vertex"), + target=gtir.SymRef(id="vertices"), ) ], ) @@ -1174,16 +1176,19 @@ def test_gtir_neighbors_as_input(): assert isinstance(connectivity_V2E, gtx_common.NeighborTable) v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ - functools.reduce(lambda x, y: x + y, v2e_neighbors, init_value) - for v2e_neighbors in v2e_field + functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) ] sdfg( v2e_field, + e, v, + connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, @@ -1386,13 +1391,48 @@ def test_gtir_reduce_dot_product(): connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + # create mesh with skip values + connectivity_V2E_skip = copy.deepcopy(connectivity_V2E) + connectivity_V2E_skip.has_skip_values = True + connectivity_V2E_skip.table = np.asarray( + [ + [x if i != skip_idx else gtx_common._DEFAULT_SKIP_VALUE for i, x in enumerate(row)] + for skip_idx, row in zip( + np.random.randint(0, connectivity_V2E.max_neighbors, size=SIMPLE_MESH.num_vertices), + connectivity_V2E.table, + strict=True, + ) + ], + dtype=connectivity_V2E.table.dtype, + ) + # safety check that the connectivity table actually contains skip values + assert len(np.where(connectivity_V2E.table == gtx_common._DEFAULT_SKIP_VALUE)) != 0 + + offset_provider = SIMPLE_MESH_OFFSET_PROVIDER | { + "V2E_skip": connectivity_V2E_skip, + } + + V2E_SKIP_SYMBOLS = dict( + __connectivity_V2E_skip_size_0=SIMPLE_MESH.num_vertices, + __connectivity_V2E_skip_size_1=connectivity_V2E_skip.max_neighbors, + __connectivity_V2E_skip_stride_0=connectivity_V2E_skip.max_neighbors, + __connectivity_V2E_skip_stride_1=1, + ) + e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) v_ref = [ functools.reduce( - lambda x, y: x + y, (e[v2e_neighbors] * e[v2e_neighbors]) + 1.0, init_value + lambda x, y: x + y, + map( + lambda x: 0.0 if x[1] == gtx_common._DEFAULT_SKIP_VALUE else x[0], + zip((e[v2e_neighbors] * e[v2e_skip_neighbors]) + 1.0, v2e_skip_neighbors), + ), + init_value, + ) + for v2e_neighbors, v2e_skip_neighbors in zip( + connectivity_V2E.table, connectivity_V2E_skip.table ) - for v2e_neighbors in connectivity_V2E.table ] stencil_inlined = im.call( @@ -1402,7 +1442,7 @@ def test_gtir_reduce_dot_product(): im.map_("plus")( im.map_("multiplies")( im.neighbors("V2E", "it"), - im.neighbors("V2E", "it"), + im.neighbors("V2E_skip", "it"), ), im.call("make_const_list")(1.0), ) @@ -1425,7 +1465,7 @@ def test_gtir_reduce_dot_product(): im.op_as_fieldop(im.map_("plus"), vertex_domain)( im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + im.as_fieldop_neighbors("V2E_skip", "edges", vertex_domain), ), im.op_as_fieldop("make_const_list", vertex_domain)(1.0), ) @@ -1450,14 +1490,16 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) sdfg( e, v, connectivity_V2E=connectivity_V2E.table, + connectivity_V2E_skip=connectivity_V2E_skip.table, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), + **V2E_SKIP_SYMBOLS, ) assert np.allclose(v, v_ref) From 5f9891ed81a01d061c2a988abaca6b5b2ed58c17 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 23 Oct 2024 15:09:43 +0200 Subject: [PATCH 012/178] bug[next]: foast2gtir lowering of broadcasted field (#1701) Wrap every broadcast in an `as_fieldop` (not only scalars). The materialization of intermediate broadcasted fields need to be optimized by transformations. --- src/gt4py/next/ffront/foast_to_gtir.py | 4 +--- .../next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 9cb0ce05f5..0d0c3868f8 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -374,9 +374,7 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) - if isinstance(node.args[0].type, ts.ScalarType): - return im.as_fieldop(im.ref("deref"))(expr) - return expr + return im.as_fieldop(im.ref("deref"))(expr) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 09f18246dc..4a1a7cba8e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -915,7 +915,7 @@ def foo(inp: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) assert lowered.id == "foo" - assert lowered.expr == im.ref("inp") + assert lowered.expr == im.as_fieldop("deref")(im.ref("inp")) def test_scalar_broadcast(): From b8020105c070d22c680126f401f4c169eca32e13 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 23 Oct 2024 12:25:52 -0400 Subject: [PATCH 013/178] fix[cartesian]: Verbose frontend error for bad call (#1700) When making a bad call to at a `gtscript.function` we add the function name in the error message for quick reference. --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index e2aa98f3cf..ade05921ef 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -459,7 +459,8 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex call_args[name] = ast.Constant(value=arg_infos[name]) except Exception as ex: raise GTScriptSyntaxError( - message="Invalid call signature", loc=nodes.Location.from_ast_node(node) + message=f"Invalid call signature when calling {call_name}", + loc=nodes.Location.from_ast_node(node), ) from ex # Rename local names in subroutine to avoid conflicts with caller context names From c1106fca21ee90fe1d36978dfe4c827f2d6955dd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 23 Oct 2024 20:28:29 +0200 Subject: [PATCH 014/178] fext[next]: GTIR embedded backend (not active in tests) (#1702) with features from #1648 --- .../next/advanced/ToolchainWalkthrough.md | 4 +- src/gt4py/next/ffront/decorator.py | 13 ++++-- src/gt4py/next/ffront/foast_to_past.py | 5 +++ src/gt4py/next/ffront/gtcallable.py | 10 +++++ src/gt4py/next/ffront/past_to_itir.py | 15 +++++-- .../next/iterator/transforms/__init__.py | 9 ++++- .../next/iterator/transforms/pass_manager.py | 37 +++++++++++++---- .../program_processors/formatters/lisp.py | 2 +- .../runners/dace_fieldview/workflow.py | 8 ++-- .../program_processors/runners/roundtrip.py | 40 +++++++++++-------- 10 files changed, 102 insertions(+), 41 deletions(-) diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index b82dea1a2f..a5a63cb56c 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -247,7 +247,7 @@ pprint.pprint(jit_args) ``` ```python -gtx.program_processors.runners.roundtrip.executor(pitir)(*jit_args.args, **jit_args.kwargs) +gtx.program_processors.runners.roundtrip.Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs) ``` ```python @@ -290,7 +290,7 @@ assert pitir2 == pitir #### Pass The result to the compile workflow and execute ```python -example_compiled = gtx.program_processors.runners.roundtrip.executor(pitir2) +example_compiled = gtx.program_processors.runners.roundtrip.Roundtrip()(pitir2) ``` ```python diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 52fe8d8116..dc2421e1d2 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,6 +34,8 @@ from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, + foast_to_gtir, + foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -560,10 +562,15 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) + # TODO(tehrengruber): We can not use transforms from `self.backend` since this can be + # a different backend than the one of the program that calls this field operator. Just use + # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return self._frontend_transforms.foast_to_itir( - toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) - ) + return foast_to_itir.foast_to_itir(self.foast_stage) + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.FunctionDefinition: + return foast_to_gtir.foast_to_gtir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 0844f63286..312ac686a2 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -45,6 +45,11 @@ def __gt_type__(self) -> ts.CallableType: def __gt_itir__(self) -> itir.Expr: return self.foast_to_itir(self.definition) + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.Expr: + # backend should have self.foast_to_itir set to foast_to_gtir + return self.foast_to_itir(self.definition) + @dataclasses.dataclass(frozen=True) class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): diff --git a/src/gt4py/next/ffront/gtcallable.py b/src/gt4py/next/ffront/gtcallable.py index beaebb3a5a..cdfb23910e 100644 --- a/src/gt4py/next/ffront/gtcallable.py +++ b/src/gt4py/next/ffront/gtcallable.py @@ -52,6 +52,16 @@ def __gt_itir__(self) -> itir.FunctionDefinition: """ ... + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + @abc.abstractmethod + def __gt_gtir__(self) -> itir.FunctionDefinition: + """ + Return iterator IR function definition representing the callable. + Used internally by the Program decorator to populate the function + definitions of the iterator IR. + """ + ... + # TODO(tehrengruber): For embedded execution a `__call__` method and for # "truly" embedded execution arguably also a `from_function` method is # required. Since field operators currently have a `__gt_type__` with a diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a20c517cce..14d705576e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -80,11 +80,18 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra gt_callables = transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).values() + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR # TODO(ricoh): The following calls to .__gt_itir__, which will use whatever - # backend is set for each of these field operators (GTCallables). Instead - # we should use the current toolchain to lower these to ITIR. This will require - # making this step aware of the toolchain it is called by (it can be part of multiple). - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + # backend is set for each of these field operators (GTCallables). Instead + # we should use the current toolchain to lower these to ITIR. This will require + # making this step aware of the toolchain it is called by (it can be part of multiple). + lowered_funcs = [] + for gt_callable in gt_callables: + if to_gtir: + lowered_funcs.append(gt_callable.__gt_gtir__()) + else: + lowered_funcs.append(gt_callable.__gt_itir__()) itir_program = ProgramLowering.apply( inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 58678cfc9c..6f9651a397 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -6,7 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator.transforms.pass_manager import LiftMode, apply_common_transforms +from gt4py.next.iterator.transforms.pass_manager import ( + ITIRTransform, + LiftMode, + apply_common_transforms, + apply_fieldview_transforms, +) -__all__ = ["apply_common_transforms", "LiftMode"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b3bb7bc6e1..7c35d552dc 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -7,11 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause import enum -from typing import Callable, Optional +from typing import Callable, Optional, Protocol from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms import fencil_to_program, infer_domain, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -29,6 +30,12 @@ from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce +class ITIRTransform(Protocol): + def __call__( + self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + ) -> itir.Program: ... + + @enum.unique class LiftMode(enum.Enum): FORCE_INLINE = enum.auto() @@ -65,7 +72,7 @@ def _inline_into_scan(ir, *, max_iter=10): # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Node, + ir: itir.Program | itir.FencilDefinition, *, lift_mode=None, offset_provider=None, @@ -115,10 +122,10 @@ def apply_common_transforms( # other cases we want it anyway. force_inline_trivial_lift_args=True, ) - inlined = ConstantFolding.apply(inlined) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # still a `itir.Program` # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( + inlined = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` inlined, offset_provider=offset_provider, # TODO(tehrengruber): disabled since it increases compile-time too much right now @@ -167,7 +174,7 @@ def apply_common_transforms( # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( + ir = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` ir, ignore_tuple_size=True, offset_provider=offset_provider, @@ -188,7 +195,7 @@ def apply_common_transforms( unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) if unrolled == ir: break - ir = unrolled + ir = unrolled # type: ignore[assignment] # still a `itir.Program` ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) @@ -200,7 +207,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) ir = MergeLet().visit(ir) ir = InlineLambdas.apply( @@ -209,3 +216,17 @@ def apply_common_transforms( assert isinstance(ir, itir.Program) return ir + + +def apply_fieldview_transforms( + ir: itir.Program, *, offset_provider: common.OffsetProvider +) -> itir.Program: + ir = inline_fundefs.InlineFundefs().visit(ir) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + ) + ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + return ir diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index c477795c34..7b722a7c1a 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -50,7 +50,7 @@ class ToLispLike(TemplatedGenerator): ) @classmethod - def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override] + def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] transformed = apply_common_transforms( root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index ffc33a9f25..f2953eb05f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -46,10 +46,8 @@ def generate_sdfg( offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> dace.SDFG: - # TODO(edopao): Call IR transformations and domain inference, finally lower IR to SDFG - raise NotImplementedError - - return gtir_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) + ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + return gtir_sdfg.build_sdfg_from_gtir(ir=ir, offset_provider=offset_provider) def __call__( self, inp: stages.CompilableProgram diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 93e6d09c5b..57785ceb33 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import functools import importlib.util import pathlib import tempfile @@ -20,7 +21,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_specifications as ts @@ -90,11 +91,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Node, + ir: itir.Program | itir.FencilDefinition, debug: bool, - lift_mode: itir_transforms.LiftMode, use_embedded: bool, offset_provider: dict[str, common.Connectivity | common.Dimension], + transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -102,7 +103,6 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. - lift_mode: Change the way lifted function calls are evaluated. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. @@ -110,15 +110,13 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) - ir = itir_transforms.apply_common_transforms( - ir, lift_mode=lift_mode, offset_provider=offset_provider - ) + ir = transforms(ir, offset_provider=offset_provider) program = EmbeddedDSL.apply(ir) @@ -187,9 +185,9 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]): debug: Optional[bool] = None - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None + transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -198,8 +196,8 @@ def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: inp.data, offset_provider=inp.args.offset_provider, debug=debug, - lift_mode=self.lift_mode, use_embedded=self.use_embedded, + transforms=self.transforms, ) def decorated_fencil( @@ -224,28 +222,38 @@ def decorated_fencil( return decorated_fencil -executor = Roundtrip() -executor_with_temporaries = Roundtrip(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES) - default = next_backend.Backend( name="roundtrip", - executor=executor, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, lift_mode=itir_transforms.LiftMode.FORCE_INLINE + ) + ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) with_temporaries = next_backend.Backend( name="roundtrip_with_temporaries", - executor=executor_with_temporaries, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, + lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES, + ) + ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) + gtir = next_backend.Backend( name="roundtrip_gtir", - executor=executor, + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() + ), ), ) From 2927eba14ae68f40db927eeb4510fb97b6540141 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 24 Oct 2024 16:00:22 +0200 Subject: [PATCH 015/178] fix[next][dace]: Avoid bool cast in branch condition expressions (#1707) Casting the condition expression on an inter-state edge to `bool` prevented dead-state elimination in the SDFG in case the expression was specialized to True/False. --- .../runners/dace_fieldview/gtir_builtin_translators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 277d8a0cd8..5e3a220caa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -426,12 +426,12 @@ def translate_if( # expect true branch as second argument true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({if_stmt})")) + sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"{if_stmt}")) sdfg.add_edge(true_state, state, dace.InterstateEdge()) # and false branch as third argument false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({if_stmt})"))) + sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not ({if_stmt})"))) sdfg.add_edge(false_state, state, dace.InterstateEdge()) true_br_args = sdfg_builder.visit( From 77a8a6d8afbdda9045f3f5fa1035e4625d496ee5 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 25 Oct 2024 09:05:28 +0200 Subject: [PATCH 016/178] feat[next][dace]: Stop carrying reduce identity in args context (#1704) This PR implements an alternative design to local expressions on lists containing skip values. The previous design was based on the assumption that such operations only exist in the context of a reduction (which is the most common case, but necessarily the only one) and was carrying the reduce identity value from a reduce node to the visiting context of the arguments and all their child nodes. The reduce identity value was used till fill the skip values when building lists of neighbors or when applying map expressions. Although quite effective for reduce expressions, this design led to a very complicated code. Besides, this design was mixing lowering with optimization, by implicitly optimizing the SDFG. With this PR, we stop carrying the reduce identity to the arguments context. Instead, in presence of skip values we just write a dummy value. When it is time to reduce, we override the dummy value with the reduce identity. In this way, the reduce identity is only used in the context of the reduction expression. --- .../gtir_builtin_translators.py | 69 +----- .../runners/dace_fieldview/gtir_dataflow.py | 232 +++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 49 ++-- 3 files changed, 188 insertions(+), 162 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 5e3a220caa..b60c86b349 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -82,7 +82,6 @@ def __call__( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Creates the dataflow subgraph representing a GTIR primitive function. @@ -94,9 +93,6 @@ def __call__( sdfg: The SDFG where the primitive subgraph should be instantiated state: The SDFG state where the result of the primitive function should be made available sdfg_builder: The object responsible for visiting child nodes of the primitive node. - reduce_identity: The value of the reduction identity, in case the primitive node - is visited in the context of a reduction expression. This value is used - by the `neighbors` primitive to provide the default value of skip neighbors. Returns: A list of data access nodes and the associated GT4Py data type, which provide @@ -112,7 +108,6 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: """Helper method to visit an expression passed as argument to a field operator.""" @@ -120,7 +115,6 @@ def _parse_fieldop_arg( node, sdfg=sdfg, head_state=state, - reduce_identity=reduce_identity, ) # arguments passed to field operator should be plain fields, not tuples of fields @@ -213,7 +207,6 @@ def translate_as_fieldop( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """ Generates the dataflow subgraph for the `as_fieldop` builtin function. @@ -244,7 +237,7 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder, reduce_identity) + return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -254,51 +247,11 @@ def translate_as_fieldop( domain = extract_domain(domain_expr) domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) - # The reduction identity value is used in place of skip values when building - # a list of neighbor values in the unstructured domain. - # - # A reduction on neighbor values can be either expressed in local view (itir): - # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ - # ← as_fieldop( - # λ(it) → reduce(plus, 0)(neighbors(V2Eₒ, it)), u⟨ Vertexₕ: [0, nvertices) ⟩ - # )(edges); - # - # or in field view (gtir): - # vertices @ u⟨ Vertexₕ: [0, nvertices) ⟩ - # ← as_fieldop(λ(it) → reduce(plus, 0)(·it), u⟨ Vertexₕ: [0, nvertices) ⟩)( - # as_fieldop(λ(it) → neighbors(V2Eₒ, it), u⟨ Vertexₕ: [0, nvertices) ⟩)(edges) - # ); - # - # In local view, the list of neighbors is (recursively) built while visiting - # the current expression. - # In field view, the list of neighbors is built as argument to the current - # expression. Therefore, the reduction identity value needs to be passed to - # the argument visitor (`reduce_identity_for_args = reduce_identity`). - if cpm.is_applied_reduce(stencil_expr.expr): - if reduce_identity is not None: - raise NotImplementedError("Nested reductions are not supported.") - _, _, reduce_identity_for_args = gtir_dataflow.get_reduce_params(stencil_expr.expr) - elif cpm.is_call_to(stencil_expr.expr, "neighbors"): - # When the visitor hits a neighbors expression, we stop carrying the reduce - # identity further (`reduce_identity_for_args = None`) because the reduce - # identity value is filled in place of skip values in the context of neighbors - # itself, not in the arguments context. - # Besides, setting `reduce_identity_for_args = None` enables a sanity check - # that the sequence 'reduce(V2E) -> neighbors(V2E) -> reduce(C2E) -> neighbors(C2E)' - # is accepted, while 'reduce(V2E) -> reduce(C2E) -> neighbors(V2E) -> neighbors(C2E)' - # is not. The latter sequence would raise the 'NotImplementedError' exception above. - reduce_identity_for_args = None - else: - reduce_identity_for_args = reduce_identity - # visit the list of arguments to be passed to the lambda expression - stencil_args = [ - _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity_for_args) - for arg in node.args - ] + stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) + taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) output_desc = output.result.dc_node.desc(sdfg) @@ -340,7 +293,6 @@ def translate_broadcast_scalar( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """ Generates the dataflow subgraph for the 'as_fieldop' builtin function for the @@ -363,9 +315,7 @@ def translate_broadcast_scalar( assert len(node.args) == 1 assert isinstance(node.args[0].type, ts.ScalarType) - scalar_expr = _parse_fieldop_arg( - node.args[0], sdfg, state, sdfg_builder, domain, reduce_identity=None - ) + scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) assert isinstance(scalar_expr, gtir_dataflow.MemletExpr) assert scalar_expr.subset == sbs.Indices.from_string("0") result = gtir_dataflow.DataflowOutputEdge( @@ -396,7 +346,6 @@ def translate_if( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for the `if_` builtin function.""" assert cpm.is_call_to(node, "if_") @@ -438,13 +387,11 @@ def translate_if( true_expr, sdfg=sdfg, head_state=true_state, - reduce_identity=reduce_identity, ) false_br_args = sdfg_builder.visit( false_expr, sdfg=sdfg, head_state=false_state, - reduce_identity=reduce_identity, ) def make_temps(output_data: FieldopData) -> FieldopData: @@ -563,7 +510,6 @@ def translate_literal( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for a `ir.Literal` node.""" assert isinstance(node, gtir.Literal) @@ -579,7 +525,6 @@ def translate_make_tuple( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert cpm.is_call_to(node, "make_tuple") return tuple( @@ -587,7 +532,6 @@ def translate_make_tuple( arg, sdfg=sdfg, head_state=state, - reduce_identity=reduce_identity, ) for arg in node.args ) @@ -598,7 +542,6 @@ def translate_tuple_get( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert cpm.is_call_to(node, "tuple_get") assert len(node.args) == 2 @@ -612,7 +555,6 @@ def translate_tuple_get( node.args[1], sdfg=sdfg, head_state=state, - reduce_identity=reduce_identity, ) if isinstance(data_nodes, FieldopData): raise ValueError(f"Invalid tuple expression {node}") @@ -630,7 +572,6 @@ def translate_scalar_expr( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert isinstance(node, gtir.FunCall) assert isinstance(node.type, ts.ScalarType) @@ -659,7 +600,6 @@ def translate_scalar_expr( arg_expr, sdfg=sdfg, head_state=state, - reduce_identity=reduce_identity, ) if not (isinstance(arg, FieldopData) and isinstance(arg.gt_dtype, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") @@ -714,7 +654,6 @@ def translate_symbol_ref( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for a `ir.SymRef` node.""" assert isinstance(node, gtir.SymRef) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 4f6a1e04c6..416321c038 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -26,7 +26,7 @@ gtir_sdfg, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -250,7 +250,6 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - reduce_identity: Optional[SymbolExpr] input_edges: list[DataflowInputEdge] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] @@ -259,12 +258,10 @@ def __init__( sdfg: dace.SDFG, state: dace.SDFGState, subgraph_builder: gtir_sdfg.DataflowBuilder, - reduce_identity: Optional[SymbolExpr], ): self.sdfg = sdfg self.state = state self.subgraph_builder = subgraph_builder - self.reduce_identity = reduce_identity self.input_edges = [] self.symbol_map = {} @@ -351,6 +348,10 @@ def _add_mapped_tasklet( name, self.state, map_ranges, inputs, code, outputs, **kwargs ) + def unique_nsdfg_name(self, prefix: str) -> str: + """Utility function to generate a unique name for a nested SDFG, starting with the given prefix.""" + return self.subgraph_builder.unique_nsdfg_name(self.sdfg, prefix) + def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr: if isinstance(field, MemletExpr): desc = field.dc_node.desc(self.sdfg) @@ -576,12 +577,15 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: } if offset_provider.has_skip_values: - if self.reduce_identity is None: - raise ValueError( - f"Found local offset '{offset}' with skip values, but 'reduce_identity' is not set." - ) - assert self.reduce_identity.dc_dtype == field_desc.dtype - tasklet_expression += f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {field_desc.dtype}({self.reduce_identity.value})" + # in case of skip value we can write any dummy value + skip_value = ( + "math.nan" + if ti.is_floating_point(node.type.element_type) + else str(dace.dtypes.max_value(field_desc.dtype)) + ) + tasklet_expression += ( + f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}" + ) self._add_mapped_tasklet( name=f"{offset}_neighbors", @@ -713,16 +717,20 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: ) ) - if self.reduce_identity is None: - raise ValueError( - f"Found local offset '{result_offset}' with skip values, but 'reduce_identity' is not set." - ) - assert self.reduce_identity.dc_dtype == dc_dtype input_memlets["__neighbor_idx"] = dace.Memlet( data=connectivity_slice.dc_node.data, subset=map_index ) input_nodes[connectivity_slice.dc_node.data] = connectivity_slice.dc_node - tasklet_expression += f" if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {dc_dtype}({self.reduce_identity.value})" + + # in case of skip value we can write any dummy value + skip_value = ( + "math.nan" + if ti.is_floating_point(node.type.element_type) + else str(dace.dtypes.max_value(dc_dtype)) + ) + tasklet_expression += ( + f" if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}" + ) self._add_mapped_tasklet( name="map", @@ -739,66 +747,158 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: return ValueExpr(result_node, dc_dtype, result_offset) - def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: - assert isinstance(node.type, ts.ScalarType) - op_name, reduce_init, reduce_identity = get_reduce_params(node) + def _make_reduce_with_skip_values( + self, + input_expr: ValueExpr | MemletExpr, + offset_provider: gtx_common.Connectivity, + reduce_init: SymbolExpr, + reduce_identity: SymbolExpr, + reduce_wcr: str, + result_node: dace.nodes.AccessNode, + ) -> None: + """ + Helper method to lower reduction on a local field containing skip values. + + The reduction is implemented as a nested SDFG containing 2 states. In first + state, the result (a scalar data node passed as argumet) is initialized. + In second state, a mapped tasklet uses a write-conflict resolution (wcr) + memlet to update the result. + We use the offset provider as a mask to identify skip values: the value + that is written to the result node is either the input value, when the + corresponding neighbor index in the connectivity table is valid, or the + identity value if the neighbor index is missing. + """ + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + + assert input_expr.local_offset is not None + connectivity = dace_utils.connectivity_identifier(input_expr.local_offset) + connectivity_node = self.state.add_access(connectivity) + connectivity_desc = connectivity_node.desc(self.sdfg) + connectivity_desc.transient = False - # The input to reduction is a list of elements on a local dimension. - # This list is provided by an argument that typically calls the neighbors - # builtin function, to built a list of neighbor values for each element - # in the field target dimension. - # We store the value of reduce identity in the visitor context to have it - # available while visiting the input to reduction; this value might be used - # by the `neighbors` visitor to fill the skip values in the neighbors list. - prev_reduce_identity = self.reduce_identity - self.reduce_identity = reduce_identity - - try: - input_expr = self.visit(node.args[0]) - finally: - # ensure that we leave the visitor in the same state as we entered - self.reduce_identity = prev_reduce_identity - - assert isinstance(input_expr, MemletExpr | ValueExpr) - input_desc = input_expr.dc_node.desc(self.sdfg) - assert isinstance(input_desc, dace.data.Array) - - if len(input_desc.shape) > 1: - assert isinstance(input_expr, MemletExpr) - ndims = len(input_desc.shape) - 1 - # the axis to be reduced is always the last one, because `reduce` is supposed - # to operate on `ListType` - assert set(input_expr.subset.size()[0:ndims]) == {1} - reduce_axes = [ndims] + desc = input_expr.dc_node.desc(self.sdfg) + if isinstance(input_expr, MemletExpr): + local_dim_indices = [i for i, size in enumerate(input_expr.subset.size()) if size != 1] else: - reduce_axes = None + local_dim_indices = list(range(len(desc.shape))) - reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") - reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_init.value) + if len(local_dim_indices) != 1: + raise NotImplementedError( + f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." + ) + local_dim_index = local_dim_indices[0] + assert desc.shape[local_dim_index] == offset_provider.max_neighbors + + # we lower the reduction map with WCR out memlet in a nested SDFG + nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) + nsdfg.add_array( + "values", + (desc.shape[local_dim_index],), + desc.dtype, + strides=(desc.strides[local_dim_index],), + ) + nsdfg.add_array( + "neighbor_indices", + (connectivity_desc.shape[1],), + connectivity_desc.dtype, + strides=(connectivity_desc.strides[1],), + ) + nsdfg.add_scalar("acc", desc.dtype) + st_init = nsdfg.add_state(f"{nsdfg.label}_init") + st_init.add_edge( + st_init.add_tasklet( + "init_acc", + {}, + {"__val"}, + f"__val = {reduce_init.dc_dtype}({reduce_init.value})", + ), + "__val", + st_init.add_access("acc"), + None, + dace.Memlet(data="acc", subset="0"), + ) + st_reduce = nsdfg.add_state_after(st_init, f"{nsdfg.label}_reduce") + # Fill skip values in local dimension with the reduce identity value + skip_value = f"{reduce_identity.dc_dtype}({reduce_identity.value})" + # Since this map operates on a pure local dimension, we explicitly set sequential + # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. + # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. + st_reduce.add_mapped_tasklet( + name="reduce_with_skip_values", + map_ranges={"i": f"0:{offset_provider.max_neighbors}"}, + inputs={ + "__val": dace.Memlet(data="values", subset="i"), + "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), + }, + code=f"__out = __val if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", + outputs={ + "__out": dace.Memlet(data="acc", subset="0", wcr=reduce_wcr, wcr_nonatomic=True), + }, + external_edges=True, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, inputs={"values", "neighbor_indices"}, outputs={"acc"} + ) if isinstance(input_expr, MemletExpr): - self._add_input_data_edge( - input_expr.dc_node, - input_expr.subset, - reduce_node, - ) + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, "values") else: - self.state.add_nedge( + self.state.add_edge( input_expr.dc_node, - reduce_node, - dace.Memlet.from_array(input_expr.dc_node.data, input_desc), + None, + nsdfg_node, + "values", + self.sdfg.make_array_memlet(input_expr.dc_node.data), ) + self._add_input_data_edge( + connectivity_node, + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + nsdfg_node, + "neighbor_indices", + ) + self.state.add_edge( + nsdfg_node, + "acc", + result_node, + None, + dace.Memlet(data=result_node.data, subset="0"), + ) - temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, reduce_identity.dc_dtype, transient=True) - temp_node = self.state.add_access(temp_name) + def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: + assert isinstance(node.type, ts.ScalarType) + op_name, reduce_init, reduce_identity = get_reduce_params(node) + reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") - self.state.add_nedge( - reduce_node, - temp_node, - dace.Memlet(data=temp_name, subset="0"), - ) - return ValueExpr(temp_node, node.type) + result = self.sdfg.temp_data_name() + self.sdfg.add_scalar(result, reduce_identity.dc_dtype, transient=True) + result_node = self.state.add_access(result) + + input_expr = self.visit(node.args[0]) + assert isinstance(input_expr, (MemletExpr, ValueExpr)) + assert input_expr.local_offset is not None + offset_provider = self.subgraph_builder.get_offset_provider(input_expr.local_offset) + assert isinstance(offset_provider, gtx_common.Connectivity) + + if offset_provider.has_skip_values: + self._make_reduce_with_skip_values( + input_expr, offset_provider, reduce_init, reduce_identity, reduce_wcr, result_node + ) + + else: + reduce_node = self.state.add_reduce(reduce_wcr, axes=None, identity=reduce_init.value) + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node) + else: + self.state.add_nedge( + input_expr.dc_node, + reduce_node, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + self.state.add_nedge(reduce_node, result_node, dace.Memlet(data=result, subset="0")) + + return ValueExpr(result_node, node.type) def _split_shift_args( self, args: list[gtir.Expr] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 4627293cfd..0adcc95cf1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -32,7 +32,6 @@ from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_builtin_translators, - gtir_dataflow, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -44,6 +43,9 @@ class DataflowBuilder(Protocol): @abc.abstractmethod def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... + @abc.abstractmethod + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... + @abc.abstractmethod def unique_map_name(self, name: str) -> str: ... @@ -168,6 +170,12 @@ def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: + nsdfg_list = [ + nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) + ] + return f"{prefix}_{len(nsdfg_list)}" + def unique_map_name(self, name: str) -> str: return f"{self.map_uids.sequential_id()}_{name}" @@ -299,7 +307,7 @@ def _visit_expression( Returns: A list of array nodes containing the result fields. """ - result = self.visit(node, sdfg=sdfg, head_state=head_state, reduce_identity=None) + result = self.visit(node, sdfg=sdfg, head_state=head_state) # sanity check: each statement should preserve the property of single exit state (aka head state), # i.e. eventually only introduce internal branches, and keep the same head state @@ -491,32 +499,22 @@ def visit_FunCall( node: gtir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): - return gtir_builtin_translators.translate_if( - node, sdfg, head_state, self, reduce_identity - ) + return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): - return gtir_builtin_translators.translate_make_tuple( - node, sdfg, head_state, self, reduce_identity - ) + return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): - return gtir_builtin_translators.translate_tuple_get( - node, sdfg, head_state, self, reduce_identity - ) + return gtir_builtin_translators.translate_tuple_get(node, sdfg, head_state, self) elif cpm.is_applied_as_fieldop(node): - return gtir_builtin_translators.translate_as_fieldop( - node, sdfg, head_state, self, reduce_identity - ) + return gtir_builtin_translators.translate_as_fieldop(node, sdfg, head_state, self) elif isinstance(node.fun, gtir.Lambda): lambda_args = [ self.visit( arg, sdfg=sdfg, head_state=head_state, - reduce_identity=reduce_identity, ) for arg in node.args ] @@ -525,13 +523,10 @@ def visit_FunCall( node.fun, sdfg=sdfg, head_state=head_state, - reduce_identity=reduce_identity, args=lambda_args, ) elif isinstance(node.type, ts.ScalarType): - return gtir_builtin_translators.translate_scalar_expr( - node, sdfg, head_state, self, reduce_identity - ) + return gtir_builtin_translators.translate_scalar_expr(node, sdfg, head_state, self) else: raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") @@ -540,7 +535,6 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], args: list[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ @@ -569,7 +563,7 @@ def visit_Lambda( # lower let-statement lambda node as a nested SDFG lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) - nsdfg = dace.SDFG(f"{sdfg.label}_lambda") + nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") # add sdfg storage for the symbols that need to be passed as input parameters @@ -585,7 +579,6 @@ def visit_Lambda( node.expr, sdfg=nsdfg, head_state=nstate, - reduce_identity=reduce_identity, ) def _flatten_tuples( @@ -716,22 +709,16 @@ def visit_Literal( node: gtir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: - return gtir_builtin_translators.translate_literal( - node, sdfg, head_state, self, reduce_identity=None - ) + return gtir_builtin_translators.translate_literal(node, sdfg, head_state, self) def visit_SymRef( self, node: gtir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: - return gtir_builtin_translators.translate_symbol_ref( - node, sdfg, head_state, self, reduce_identity=None - ) + return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) def build_sdfg_from_gtir( From eb05a0a29374a4a0fd9c9d65b0c39cf5ff2555be Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 25 Oct 2024 09:51:39 +0200 Subject: [PATCH 017/178] feat[next][dace]: Fix lowering of nested let-statements (#1697) This PR fixes one corner case of nested let-statements, discovered in `test_tuple_unpacking_star_multi` during GTIR integration. Test case added. Additionally, fixed handling of symbol already defined in SDFG for #1695. --- .../runners/dace_fieldview/gtir_sdfg.py | 122 +++++++++++------- .../dace_tests/test_gtir_to_sdfg.py | 53 +++++++- 2 files changed, 125 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 0adcc95cf1..28eef5c260 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -266,18 +266,20 @@ def _add_storage( elif isinstance(gt_type, ts.ScalarType): dc_dtype = dace_utils.as_dace_type(gt_type) if name in symbolic_arguments: - sdfg.add_symbol(name, dc_dtype) - elif dace_utils.is_field_symbol(name): - # Sometimes, when the field domain is implicitly derived from the - # field domain, the gt4py lowering adds the field size as a scalar - # argument to the program IR. Suppose a field '__sym', then gt4py - # will add '__sym_size_0'. - # Therefore, here we check whether the shape symbol was already - # created by `_make_array_shape_and_strides`, when allocating - # storage for field arguments. We assume that the scalar argument - # for field size, if present, always follows the field argument. if name in sdfg.symbols: - assert sdfg.symbols[name].dc_dtype == dc_dtype + # Sometimes, when the field domain is implicitly derived from the + # field domain, the gt4py lowering adds the field size as a scalar + # argument to the program IR. Suppose a field '__sym', then gt4py + # will add '__sym_size_0'. + # Therefore, here we check whether the shape symbol was already + # created by `_make_array_shape_and_strides()`, when allocating + # storage for field arguments. We assume that the scalar argument + # for field size, if present, always follows the field argument. + assert dace_utils.is_field_symbol(name) + if sdfg.symbols[name].dtype != dc_dtype: + raise ValueError( + f"Type mismatch on argument {name}: got {dc_dtype}, expected {sdfg.symbols[name].dtype}." + ) else: sdfg.add_symbol(name, dc_dtype) else: @@ -599,6 +601,9 @@ def _flatten_tuples( # Process lambda inputs # + # All input arguments are passed as parameters to the nested SDFG, therefore + # we they are stored as non-transient array and scalar objects. + # lambda_arg_nodes = dict( itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) ) @@ -639,68 +644,89 @@ def _flatten_tuples( # Process lambda outputs # + # The output arguments do not really exist, so they are not allocated before + # visiting the lambda expression. Therefore, the result appears inside the + # nested SDFG as transient array/scalar storage. The exception is given by + # input arguments that are just passed through and returned by the lambda, + # e.g. when the lambda is constructing a tuple: in this case, the result + # data is non-transient, because it corresponds to an input node. + # The transient storage of the lambda result in nested-SDFG is corrected + # below by the call to `make_temps()`: this function ensures that the result + # transient nodes are changed to non-transient and the corresponding output + # connecters on the nested SDFG are connected to new data nodes in parent SDFG. + # lambda_output_data: Iterable[gtir_builtin_translators.FieldopData] = ( gtx_utils.flatten_nested_tuple(lambda_result) ) - # sanity check on isolated nodes - assert all( - nstate.degree(output_data.dc_node) == 0 - for output_data in lambda_output_data - if output_data.dc_node.data in input_memlets - ) - # keep only non-isolated output nodes + # The output connectors only need to be setup for the actual result of the + # internal dataflow that writes to transient nodes. + # We filter out the non-transient nodes because they are already available + # in the current context. Later these nodes will eventually be removed + # from the nested SDFG because they are isolated (see `make_temps()`). lambda_outputs = { output_data.dc_node.data for output_data in lambda_output_data - if output_data.dc_node.data not in input_memlets + if output_data.dc_node.desc(nsdfg).transient } - if lambda_outputs: - nsdfg_node = head_state.add_nested_sdfg( - nsdfg, - parent=sdfg, - inputs=set(input_memlets.keys()), - outputs=lambda_outputs, - symbol_mapping=nsdfg_symbols_mapping, - debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo), - ) + nsdfg_node = head_state.add_nested_sdfg( + nsdfg, + parent=sdfg, + inputs=set(input_memlets.keys()), + outputs=lambda_outputs, + symbol_mapping=nsdfg_symbols_mapping, + debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo), + ) - for connector, memlet in input_memlets.items(): - if connector in lambda_arg_nodes: - src_node = lambda_arg_nodes[connector].dc_node - else: - src_node = head_state.add_access(memlet.data) + for connector, memlet in input_memlets.items(): + if connector in lambda_arg_nodes: + src_node = lambda_arg_nodes[connector].dc_node + else: + src_node = head_state.add_access(memlet.data) - head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) + head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) def make_temps( output_data: gtir_builtin_translators.FieldopData, ) -> gtir_builtin_translators.FieldopData: - if output_data.dc_node.data in lambda_outputs: - connector = output_data.dc_node.data - desc = output_data.dc_node.desc(nsdfg) - # make lambda result non-transient and map it to external temporary + """ + This function will be called while traversing the result of the lambda + dataflow to setup the intermediate data nodes in the parent SDFG and + the data edges from the nested-SDFG output connectors. + """ + desc = output_data.dc_node.desc(nsdfg) + if desc.transient: + # Transient nodes actually contain some result produced by the dataflow + # itself, therefore these nodes are changed to non-transient and an output + # edge will write the result from the nested-SDFG to a new intermediate + # data node in the parent context. desc.transient = False - # isolated access node will make validation fail - if nstate.degree(output_data.dc_node) == 0: - nstate.remove_node(output_data.dc_node) temp, _ = sdfg.add_temp_transient_like(desc) + connector = output_data.dc_node.data dst_node = head_state.add_access(temp) head_state.add_edge( nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) ) - return gtir_builtin_translators.FieldopData( + temp_field = gtir_builtin_translators.FieldopData( dst_node, output_data.gt_dtype, output_data.local_offset ) elif output_data.dc_node.data in lambda_arg_nodes: - nstate.remove_node(output_data.dc_node) - return lambda_arg_nodes[output_data.dc_node.data] + # This if branch and the next one handle the non-transient result nodes. + # Non-transient nodes are just input nodes that are immediately returned + # by the lambda expression. Therefore, these nodes are already available + # in the parent context and can be directly accessed there. + temp_field = lambda_arg_nodes[output_data.dc_node.data] else: - nstate.remove_node(output_data.dc_node) - data_node = head_state.add_access(output_data.dc_node.data) - return gtir_builtin_translators.FieldopData( - data_node, output_data.gt_dtype, output_data.local_offset + dc_node = head_state.add_access(output_data.dc_node.data) + temp_field = gtir_builtin_translators.FieldopData( + dc_node, output_data.gt_dtype, output_data.local_offset ) + # Isolated access node will make validation fail. + # Isolated access nodes can be found in the join-state of an if-expression + # or in lambda expressions that just construct tuples from input arguments. + if nstate.degree(output_data.dc_node) == 0: + nstate.remove_node(output_data.dc_node) + return temp_field return gtx_utils.tree_map(make_temps)(lambda_result) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 2c1c04ce1f..5377654b55 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1772,10 +1772,10 @@ def test_gtir_let_lambda_with_cond(): assert np.allclose(b, a if s else a * 2) -def test_gtir_let_lambda_with_tuple(): +def test_gtir_let_lambda_with_tuple1(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) testee = gtir.Program( - id="let_lambda_with_tuple", + id="let_lambda_with_tuple1", function_definitions=[], params=[ gtir.Sym(id="x", type=IFTYPE), @@ -1816,6 +1816,55 @@ def test_gtir_let_lambda_with_tuple(): assert np.allclose(z_fields[1], b) +def test_gtir_let_lambda_with_tuple2(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + val = np.random.rand() + testee = gtir.Program( + id="let_lambda_with_tuple2", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="z", type=ts.TupleType(types=[IFTYPE, IFTYPE, IFTYPE])), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("s", im.as_fieldop("deref", domain)(val))( + im.let("t", im.make_tuple("x", "y"))( + im.let("p", im.op_as_fieldop("plus", domain)("x", "y"))( + im.make_tuple("p", "s", im.tuple_get(1, "t")) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="z"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) + z_symbols = dict( + __z_0_size_0=FSYMBOLS["__x_size_0"], + __z_0_stride_0=FSYMBOLS["__x_stride_0"], + __z_1_size_0=FSYMBOLS["__x_size_0"], + __z_1_stride_0=FSYMBOLS["__x_stride_0"], + __z_2_size_0=FSYMBOLS["__x_size_0"], + __z_2_stride_0=FSYMBOLS["__x_stride_0"], + ) + + sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + assert np.allclose(z_fields[0], a + b) + assert np.allclose(z_fields[1], val) + assert np.allclose(z_fields[2], b) + + def test_gtir_if_scalars(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) testee = gtir.Program( From db249bdbc2d82dc504a708122eb65ea99c610c61 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 25 Oct 2024 12:21:21 +0200 Subject: [PATCH 018/178] bug[next][dace]: Fix lowering of broadcast (#1698) Fix lowering of `as_fieldop` with broadcast expression after changes in PR #1701. Additional change: - Add support for GT4Py zero-dimensional fields, equivalent of numpy zero-dimensional arrays. Test case added. --- .../runners/dace_common/dace_backend.py | 7 + .../gtir_builtin_translators.py | 93 ++++++++--- .../runners/dace_fieldview/gtir_dataflow.py | 150 +++++++++--------- .../runners/dace_fieldview/gtir_sdfg.py | 4 + .../ffront_tests/test_execution.py | 1 + .../dace_tests/test_gtir_to_sdfg.py | 29 ++++ 6 files changed, 187 insertions(+), 97 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 6039c82fdb..5d3cc7a358 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -26,6 +26,13 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: if not isinstance(arg, gtx_common.Field): return arg + if len(arg.domain.dims) == 0: + # Pass zero-dimensional fields as scalars. + # We need to extract the scalar value from the 0d numpy array without changing its type. + # Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar, + # which may change its precision. To avoid this, we use here the empty tuple as index + # for 'ndarray.__getitem__()'. + return arg.ndarray[()] # field domain offsets are not supported non_zero_offsets = [ (dim, dim_range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index b60c86b349..bb37440fe2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -138,6 +138,32 @@ def _parse_fieldop_arg( raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.") +def _get_field_shape( + domain: FieldopDomain, +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: + """ + Parse the field operator domain and generates the shape of the result field. + + It should be enough to allocate an array with shape (upper_bound - lower_bound) + but this would require to use array offset for compensate for the start index. + Suppose that a field operator executes on domain [2,N-2], the dace array to store + the result only needs size (N-4), but this would require to compensate all array + accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose + to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset + is known to cause issues to SDFG inlining. Besides, map fusion will in any case + eliminate most of transient arrays. + + Args: + domain: The field operator domain. + + Returns: + A tuple of two lists: the list of field dimensions and the list of dace + array sizes in each dimension. + """ + domain_dims, _, domain_ubs = zip(*domain) + return list(domain_dims), list(domain_ubs) + + def _create_temporary_field( sdfg: dace.SDFG, state: dace.SDFGState, @@ -146,17 +172,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - domain_dims, _, domain_ubs = zip(*domain) - field_dims = list(domain_dims) - # It should be enough to allocate an array with shape (upper_bound - lower_bound) - # but this would require to use array offset for compensate for the start index. - # Suppose that a field operator executes on domain [2,N-2], the dace array to store - # the result only needs size (N-4), but this would require to compensate all array - # accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose - # to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset - # is known to cause issues to SDFG inlining. Besides, map fusion will in any case - # eliminate most of transient arrays. - field_shape = list(domain_ubs) + field_dims, field_shape = _get_field_shape(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -311,17 +327,46 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + field_dims, field_shape = _get_field_shape(domain) + field_subset = sbs.Range.from_string( + ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) + ) assert len(node.args) == 1 - assert isinstance(node.args[0].type, ts.ScalarType) scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - assert isinstance(scalar_expr, gtir_dataflow.MemletExpr) - assert scalar_expr.subset == sbs.Indices.from_string("0") - result = gtir_dataflow.DataflowOutputEdge( - state, gtir_dataflow.ValueExpr(scalar_expr.dc_node, node.args[0].type) - ) - result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result) + + if isinstance(node.args[0].type, ts.ScalarType): + assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) + input_subset = ( + str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" + ) + input_node = scalar_expr.dc_node + gt_dtype = node.args[0].type + elif isinstance(node.args[0].type, ts.FieldType): + assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) + if len(node.args[0].type.dims) == 0: # zero-dimensional field + input_subset = "0" + elif all( + isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr) + for dim in scalar_expr.dimensions + if dim not in field_dims + ): + input_subset = ",".join( + dace_gtir_utils.get_map_variable(dim) + if dim in field_dims + else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above + for dim in scalar_expr.dimensions + ) + else: + raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + + input_node = scalar_expr.field + gt_dtype = node.args[0].type.dtype + else: + raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") + + output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( "broadcast", @@ -330,15 +375,15 @@ def translate_broadcast_scalar( dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" for dim, lower_bound, upper_bound in domain }, - inputs={"__inp": dace.Memlet(data=scalar_expr.dc_node.data, subset="0")}, + inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=result_field.dc_node.data, subset=domain_indices)}, - input_nodes={scalar_expr.dc_node.data: scalar_expr.dc_node}, - output_nodes={result_field.dc_node.data: result_field.dc_node}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + input_nodes={input_node.data: input_node}, + output_nodes={output_node.data: output_node}, external_edges=True, ) - return result_field + return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype), local_offset=None) def translate_if( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 416321c038..cf91d15aba 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -426,82 +426,86 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: assert len(node.args) == 1 arg_expr = self.visit(node.args[0]) - if isinstance(arg_expr, IteratorExpr): - field_desc = arg_expr.field.desc(self.sdfg) - assert len(field_desc.shape) == len(arg_expr.dimensions) - if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(arg_expr.dimensions, field_desc.shape) - ) - return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset) - - else: - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] - index_connectors = [ - IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - if not isinstance(index, SymbolExpr) - ] - # here `internals` refer to the names used as index in the tasklet code string: - # an index can be either a connector name (for dynamic/indirect indices) - # or a symbol value (for literal values and scalar arguments). - index_internals = ",".join( - str(index.value) - if isinstance(index, SymbolExpr) - else IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - ) - deref_node = self._add_tasklet( - "runtime_deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", - ) - # add new termination point for the field parameter - self._add_input_data_edge( - arg_expr.field, - sbs.Range.from_array(field_desc), - deref_node, - "field", - ) + if not isinstance(arg_expr, IteratorExpr): + # dereferencing a scalar or a literal node results in the node itself + return arg_expr - for dim, index_expr in field_indices: - # add termination points for the dynamic iterator indices - deref_connector = IndexConnectorFmt.format(dim=dim.value) - if isinstance(index_expr, MemletExpr): - self._add_input_data_edge( - index_expr.dc_node, - index_expr.subset, - deref_node, - deref_connector, - ) - - elif isinstance(index_expr, ValueExpr): - self._add_edge( - index_expr.dc_node, - None, - deref_node, - deref_connector, - dace.Memlet(data=index_expr.dc_node.data, subset="0"), - ) - else: - assert isinstance(index_expr, SymbolExpr) - - dc_dtype = arg_expr.field.desc(self.sdfg).dtype - return self._construct_tasklet_result( - dc_dtype, deref_node, "val", arg_expr.local_offset - ) + field_desc = arg_expr.field.desc(self.sdfg) + if isinstance(field_desc, dace.data.Scalar): + # deref a zero-dimensional field + assert len(arg_expr.dimensions) == 0 + assert isinstance(node.type, ts.ScalarType) + return MemletExpr(arg_expr.field, subset="0") + # default case: deref a field with one or more dimensions + assert len(field_desc.shape) == len(arg_expr.dimensions) + if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): + # when all indices are symblic expressions, we can perform direct field access through a memlet + field_subset = sbs.Range( + (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] + if dim in arg_expr.indices + else (0, size - 1, 1) + for dim, size in zip(arg_expr.dimensions, field_desc.shape) + ) + return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset) else: - # dereferencing a scalar or a literal node results in the node itself - return arg_expr + # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, + # either indirection through connectivity table or dynamic cartesian offset. + assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) + field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + index_connectors = [ + IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + if not isinstance(index, SymbolExpr) + ] + # here `internals` refer to the names used as index in the tasklet code string: + # an index can be either a connector name (for dynamic/indirect indices) + # or a symbol value (for literal values and scalar arguments). + index_internals = ",".join( + str(index.value) + if isinstance(index, SymbolExpr) + else IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + ) + deref_node = self._add_tasklet( + "runtime_deref", + {"field"} | set(index_connectors), + {"val"}, + code=f"val = field[{index_internals}]", + ) + # add new termination point for the field parameter + self._add_input_data_edge( + arg_expr.field, + sbs.Range.from_array(field_desc), + deref_node, + "field", + ) + + for dim, index_expr in field_indices: + # add termination points for the dynamic iterator indices + deref_connector = IndexConnectorFmt.format(dim=dim.value) + if isinstance(index_expr, MemletExpr): + self._add_input_data_edge( + index_expr.dc_node, + index_expr.subset, + deref_node, + deref_connector, + ) + + elif isinstance(index_expr, ValueExpr): + self._add_edge( + index_expr.dc_node, + None, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.dc_node.data, subset="0"), + ) + else: + assert isinstance(index_expr, SymbolExpr) + + return self._construct_tasklet_result( + field_desc.dtype, deref_node, "val", arg_expr.local_offset + ) def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert len(node.args) == 2 diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 28eef5c260..e489f130db 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -255,6 +255,10 @@ def _add_storage( return tuple_fields elif isinstance(gt_type, ts.FieldType): + if len(gt_type.dims) == 0: + # represent zero-dimensional fields as scalar arguments + return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) + # handle default case: field with one or more dimensions dc_dtype = dace_utils.as_dace_type(gt_type.dtype) # use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 36d6debf9d..7540d52fb3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -219,6 +219,7 @@ def testee(a: tuple[int32, tuple[int32, int32]]) -> cases.VField: @pytest.mark.uses_tuple_args +@pytest.mark.uses_zero_dimensional_fields def test_zero_dim_tuple_arg(unstructured_case): @gtx.field_operator def testee( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 5377654b55..41f540d3cf 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -372,6 +372,35 @@ def test_gtir_tuple_broadcast_scalar(): assert np.allclose(d, a + 2 * b + 3 * c) +def test_gtir_zero_dim_fields(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="gtir_zero_dim_fields", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[], dtype=IFTYPE.dtype)), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.as_fieldop("deref", domain)("x"), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.asarray(np.random.rand()) + b = np.empty(N) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a.item(), b, **FSYMBOLS) + assert np.allclose(a, b) + + def test_gtir_tuple_return(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) testee = gtir.Program( From 692c14b7784258a28202842f8e9a7a7b307dc96b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 28 Oct 2024 10:31:24 +0100 Subject: [PATCH 019/178] refactor[cartesian]: Replace `is_start_state=` with `is_start_block=` when buiding the SDFG for DaCe (#1709) When building an SDFG for DaCe it is sometimes necessary to specify the starting point of the graph. In the past, `is_start_state=` was used for this. In never versions of Dace, `is_start_state` has been deprecated and replaced with `is_start_block`. This PR replaces the two occurrences in the `cartesian/` folder and removes warnings generated during test runs. It looks like the `next/` folder has already been cleaned. --- src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py | 2 +- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 9d64464377..7b0f0ab7c4 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -323,7 +323,7 @@ def visit_NestedSDFG( ) -> dace.nodes.NestedSDFG: sdfg = dace.SDFG(node.label) inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext( - sdfg=sdfg, state=sdfg.add_state(is_start_state=True) + sdfg=sdfg, state=sdfg.add_state(is_start_block=True) ) self.visit( node.field_decls, diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 3555d555f9..f12c13cd0e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -39,7 +39,7 @@ class SDFGContext: def __init__(self, stencil: oir.Stencil): self.sdfg = dace.SDFG(stencil.name) - self.last_state = self.sdfg.add_state(is_start_state=True) + self.last_state = self.sdfg.add_state(is_start_block=True) self.decls = {decl.name: decl for decl in stencil.params + stencil.declarations} self.block_extents = compute_horizontal_block_extents(stencil) From 96f67dcc3f008704d847f3be439f16c21204c435 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 28 Oct 2024 18:35:48 +0100 Subject: [PATCH 020/178] feat[next]: Allow partial type inference on ITIR (#1706) --- .../iterator/transforms/collapse_tuple.py | 9 ++++- .../next/iterator/type_system/inference.py | 19 ++++++---- .../iterator/type_system/type_synthesizer.py | 3 ++ .../iterator_tests/test_type_inference.py | 37 +++++++++++++++++++ .../transforms_tests/test_collapse_tuple.py | 9 +++++ 5 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 40d98208dd..b61fb2ba87 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -192,8 +192,13 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ # tuple argument differs, just continue with the rest of the tree return None - assert self.ignore_tuple_size or isinstance(first_expr.type, ts.TupleType) - if self.ignore_tuple_size or len(first_expr.type.types) == len(node.args): # type: ignore[union-attr] # ensured by assert above + assert self.ignore_tuple_size or isinstance( + first_expr.type, (ts.TupleType, ts.DeferredType) + ) + if self.ignore_tuple_size or ( + isinstance(first_expr.type, ts.TupleType) + and len(first_expr.type.types) == len(node.args) + ): return first_expr return None diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 4640aa11d1..edcb9b540c 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -84,7 +84,7 @@ def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): is_compatible &= _is_compatible_type(arg_a, arg_b) is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) else: - is_compatible &= type_a == type_b + is_compatible &= type_info.is_concretizable(type_a, type_b) return is_compatible @@ -435,7 +435,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): - if node.type: + if node.type and not isinstance(node.type, ts.DeferredType): assert _is_compatible_type(node.type, result) node.type = result elif isinstance(result, ObservableTypeSynthesizer) or result is None: @@ -511,17 +511,18 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: path, node.expr.type, ) - assert isinstance(target_type, ts.FieldType) - assert isinstance(expr_type, ts.FieldType) + assert isinstance(target_type, (ts.FieldType, ts.DeferredType)) + assert isinstance(expr_type, (ts.FieldType, ts.DeferredType)) # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain # first. Since the expr inherits the ordering from the domain this can lead to a mismatch # between the target and expr (e.g. when the target has dimension K, Vertex). We should # probably just change the behaviour of the lowering. Until then we do this more # complicated comparison. - assert ( - set(expr_type.dims) == set(target_type.dims) - and target_type.dtype == expr_type.dtype - ) + if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): + assert ( + set(expr_type.dims).issubset(set(target_type.dims)) + and target_type.dtype == expr_type.dtype + ) # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: @@ -623,6 +624,8 @@ def visit_FunCall( self.visit(tuple_, ctx=ctx) # ensure tuple is typed assert isinstance(index_literal, itir.Literal) index = int(index_literal.value) + if isinstance(tuple_.type, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(tuple_.type, ts.TupleType) return tuple_.type.types[index] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c836de1391..c55cfd8d51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -291,6 +291,9 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: + if any(isinstance(f, ts.DeferredType) for f in fields): + return ts.DeferredType(constraint=None) + stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider=offset_provider, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 20a1d7e9b7..7b6214fb1b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -103,6 +103,10 @@ def expression_test_cases(): # tuple_get (im.tuple_get(0, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), int_type), (im.tuple_get(1, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), bool_type), + ( + im.tuple_get(0, im.ref("t", ts.DeferredType(constraint=None))), + ts.DeferredType(constraint=None), + ), # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), @@ -171,6 +175,12 @@ def expression_test_cases(): )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), + ( + im.as_fieldop(im.lambda_("x")(im.deref("x")))( + im.ref("inp", ts.DeferredType(constraint=None)) + ), + ts.DeferredType(constraint=None), + ), # if in field-view scope ( im.if_( @@ -458,6 +468,33 @@ def test_program_tuple_setat_short_target(): ) +def test_program_setat_without_domain(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) + + testee = itir.Program( + id="f", + function_definitions=[], + params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref("x")))("inp"), + domain=cartesian_domain, + target=im.ref("out", float_i_field), + ) + ], + ) + + result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + + assert ( + isinstance(result.body[0].expr.type, ts.DeferredType) + and result.body[0].expr.type.constraint == ts.FieldType + ) + + def test_if_stmt(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index bcf8b726be..720076c8c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -8,6 +8,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.type_system import type_specifications as ts def test_simple_make_tuple_tuple_get(): @@ -213,3 +214,11 @@ def test_if_on_tuples_with_let(): testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True ) assert actual == expected + + +def test_tuple_get_on_untyped_ref(): + # test pass gracefully handles untyped nodes. + testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) + + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True) + assert actual == testee From c78bafd39da35b26a06a74b17346161a4cadb3b9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 29 Oct 2024 17:57:17 +0100 Subject: [PATCH 021/178] tests: fix name of license file configured in pyproject.toml (#1718) The license file configured in `pyproject.toml` was missing the `*.txt` extension, leading to warnings in tests. Parent: https://github.com/GEOS-ESM/SMT-Nebulae/issues/89 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1bb05c11c5..3c3efab625 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ keywords = [ 'portable', 'hpc' ] -license = {file = 'LICENSE'} +license = {file = 'LICENSE.txt'} name = 'gt4py' readme = 'README.md' requires-python = '>=3.8' From f1c9d83c09836aeb2eff67898eff3604124b63b7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 31 Oct 2024 10:06:21 +0100 Subject: [PATCH 022/178] ci: use actions/checkout@v4 and actions/setup-python@v5 (#1717) This PR updates the GitHub Actions (GHA) workflows to use `actions/checkout@v4` instead of `v3` (or `v2` in the cartesian case) and `actions/setup-python@v5`. Development for `actions/checkout@v3` and `actions/setup-python@v4` stopped ~1 year ago and GH is currently enforcing newer node versions than the one that these actions were designed with, leading to the following warnings ![image](https://github.com/user-attachments/assets/52924274-a00c-451c-a5a3-810fba1e6b27) _warnings in next workflows_ ![image](https://github.com/user-attachments/assets/af7f0027-f6d1-4998-a28a-c76c13dcaebb) _warnings in cartesian workflows_ `deploy_release` action was following `actions/checkout@master`. Was this on purpose? Happy to revert if so. Unless there's a good reason, I suggest to keep all actions pinned at ideally the same major version. `pre-commit/action` was updated to keep its dependencies up to date and avoid transitive warnings similar to the ones above. No changes made to currently disabled workflows, i.e. the ones under `.github/workflows/_disabled/`. Parent: https://github.com/GEOS-ESM/SMT-Nebulae/issues/89 --- .github/workflows/code-quality.yml | 6 +++--- .github/workflows/daily-ci.yml | 4 ++-- .github/workflows/deploy-release.yml | 4 ++-- .github/workflows/test-cartesian.yml | 4 ++-- .github/workflows/test-eve.yml | 4 ++-- .github/workflows/test-next.yml | 4 ++-- .github/workflows/test-notebooks.yml | 4 ++-- .github/workflows/test-storage.yml | 4 ++-- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 2137cd871a..ee5ccce53c 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -14,9 +14,9 @@ jobs: code-quality: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' @@ -24,4 +24,4 @@ jobs: **/pyproject.toml **/constraints.txt **/requirements-dev.txt - - uses: pre-commit/action@v3.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 42f96659e0..30ad0a6ff9 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -23,7 +23,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash @@ -42,7 +42,7 @@ jobs: mv boost_1_76_0/boost boost/include/ echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/.github/workflows/deploy-release.yml b/.github/workflows/deploy-release.yml index 048a6f73e1..9ce6983de1 100644 --- a/.github/workflows/deploy-release.yml +++ b/.github/workflows/deploy-release.yml @@ -14,9 +14,9 @@ jobs: name: Build Python distribution runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install pypa/build diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 5d23577bc9..aa59660a68 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -28,7 +28,7 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] tox-factor: [internal, dace] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install boost shell: bash run: | @@ -40,7 +40,7 @@ jobs: mv boost_1_76_0/boost boost/include/ echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 061f7cd484..bfd6d8e481 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -28,9 +28,9 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 8e05bbc86a..1460a5bdf4 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -27,7 +27,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash @@ -39,7 +39,7 @@ jobs: run: | sudo apt install libboost-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml index 39298b5427..4a65b7f30d 100644 --- a/.github/workflows/test-notebooks.yml +++ b/.github/workflows/test-notebooks.yml @@ -20,9 +20,9 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index e76526c296..2f85670eeb 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -30,9 +30,9 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: 'pip' From e8f11fe4a2356f043ab1998aad405ffe462fa267 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 1 Nov 2024 17:05:20 +0100 Subject: [PATCH 023/178] bug[next]: fix lowering of tuples of neighbors in conditionals (#1710) Use the `_map` function in all cases where mapping of with as_fieldop/lifted stencil *and* mapping of lists is required. --- src/gt4py/next/ffront/foast_to_gtir.py | 78 ++++++++++------ src/gt4py/next/ffront/foast_to_itir.py | 50 +++++++---- src/gt4py/next/ffront/lowering_utils.py | 24 ++--- tests/next_tests/integration_tests/cases.py | 1 + .../ffront_tests/test_gt4py_builtins.py | 88 +++++++++++++++++++ 5 files changed, 187 insertions(+), 54 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 0d0c3868f8..10583b90ff 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -53,8 +53,8 @@ def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs)) -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): +def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node_type): return lambda x: im.op_as_fieldop("make_const_list")(x) return lambda x: x @@ -215,16 +215,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._map("not_", node.operand) + return self._lower_and_map("not_", node.operand) - return self._map( + return self._lower_and_map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, ) def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( @@ -236,7 +236,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = self.visit(node.func, **kwargs) @@ -338,34 +338,43 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id - def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall: - if isinstance(t, ts.FieldType): + def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: + if isinstance(t[0], ts.FieldType): return im.as_fieldop( im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) )(expr) else: - assert isinstance(t, ts.ScalarType) + assert isinstance(t[0], ts.ScalarType) return im.call("cast_")(expr, str(new_type)) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return create_cast(obj, node.type) + return create_cast(obj, (node.args[0].type,)) - return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True) + return lowering_utils.process_elements( + create_cast, obj, node.type, arg_types=(node.args[0].type,) + ) def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return im.op_as_fieldop("if_")(*self.visit(node.args)) + return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: - return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) + def create_if( + true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] + ) -> itir.FunCall: + return _map( + "if_", + (im.ref(cond_symref_name), true_, false_), + (node.args[0].type, *arg_types), + ) result = lowering_utils.process_elements( create_if, (self.visit(node.args[1]), self.visit(node.args[2])), node.type, + arg_types=(node.args[1].type, node.args[2].type), ) return im.let(cond_symref_name, cond_)(result) @@ -377,7 +386,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.as_fieldop(im.ref("deref"))(expr) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any @@ -436,19 +445,34 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if all( - isinstance(t, ts.ScalarType) - for arg in args - for t in type_info.primitive_constituents(arg.type) - ): - return im.call(op)(*lowered_args) # scalar operation - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) + def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + return _map( + op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) + ) + + +def _map( + op: itir.Expr | str, + lowered_args: tuple, + original_arg_types: tuple[ts.TypeSpec, ...], +) -> itir.FunCall: + """ + Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. + """ + if all( + isinstance(t, ts.ScalarType) + for arg_type in original_arg_types + for t in type_info.primitive_constituents(arg_type) + ): + return im.call(op)(*lowered_args) # scalar operation + if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): + lowered_args = tuple( + promote_to_list(arg_type)(larg) + for arg_type, larg in zip(original_arg_types, lowered_args) + ) + op = im.call("map_")(op) - return im.op_as_fieldop(im.call(op))(*lowered_args) + return im.op_as_fieldop(im.call(op))(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 7936eda1cf..538b0f3ddb 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -55,8 +55,8 @@ def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): +def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node_type): return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) return lambda x: x @@ -267,16 +267,16 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._map("not_", node.operand) + return self._lower_and_map("not_", node.operand) - return self._map( + return self._lower_and_map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, ) def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: op = "if_" @@ -286,7 +286,9 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC for arg in args ] if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + lowered_args = [ + promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) + ] op = im.call("map_")(op) return lowering_utils.to_tuples_of_iterator( @@ -294,7 +296,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = self.visit(node.func, **kwargs) @@ -408,9 +410,12 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: lowered_condition = self.visit(condition, **kwargs) return lowering_utils.process_elements( - lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), + lambda tv, fv, types: _map( + "if_", (lowered_condition, tv, fv), (condition.type, *types) + ), [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], node.type, + (node.args[1].type, node.args[2].type), ) _visit_concat_where = _visit_where @@ -419,7 +424,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self.visit(node.args[0], **kwargs) def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any @@ -480,13 +485,28 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) + def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + return _map( + op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) + ) + + +def _map( + op: itir.Expr | str, + lowered_args: tuple, + original_arg_types: tuple[ts.TypeSpec, ...], +) -> itir.FunCall: + """ + Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists. + """ + if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): + lowered_args = tuple( + promote_to_list(arg_type)(larg) + for arg_type, larg in zip(original_arg_types, lowered_args) + ) + op = im.call("map_")(op) - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index a52581edb0..7049f70021 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next.ffront import type_info as ti_ffront @@ -102,7 +102,7 @@ def process_elements( process_func: Callable[..., itir.Expr], objs: itir.Expr | Iterable[itir.Expr], current_el_type: ts.TypeSpec, - with_type: bool = False, + arg_types: Optional[Iterable[ts.TypeSpec]] = None, ) -> itir.FunCall: """ Recursively applies a processing function to all primitive constituents of a tuple. @@ -113,9 +113,9 @@ def process_elements( objs: The object whose elements are to be transformed. current_el_type: A type with the same structure as the elements of `objs`. The leaf-types are not used and thus not relevant. - current_el_type: A type with the same structure as the elements of `objs`. Unless `with_type=True` - the leaf-types are not used and thus not relevant. - with_type: If True, the last argument passed to `process_func` will be its type. + arg_types: If provided, a tuple of the type of each argument is passed to `process_func` as last argument. + Note, that `arg_types` might coincide with `(current_el_type,)*len(objs)`, but not necessarily, + in case of implicit broadcasts. """ if isinstance(objs, itir.Expr): objs = (objs,) @@ -125,7 +125,7 @@ def process_elements( process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type, - with_type=with_type, + arg_types=arg_types, ) return im.let(*(zip(let_ids, objs, strict=True)))(body) @@ -138,7 +138,7 @@ def _process_elements_impl( process_func: Callable[..., itir.Expr], _current_el_exprs: Iterable[T], current_el_type: ts.TypeSpec, - with_type: bool, + arg_types: Optional[Iterable[ts.TypeSpec]], ) -> itir.Expr: if isinstance(current_el_type, ts.TupleType): result = im.make_tuple( @@ -149,16 +149,16 @@ def _process_elements_impl( im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs ), current_el_type.types[i], - with_type=with_type, + arg_types=tuple(arg_t.types[i] for arg_t in arg_types) # type: ignore[attr-defined] # guaranteed by the requirement that `current_el_type` and each element of `arg_types` have the same tuple structure + if arg_types is not None + else None, ) for i in range(len(current_el_type.types)) ) ) - elif type_info.contains_local_field(current_el_type): - raise NotImplementedError("Processing fields with local dimension is not implemented.") else: - if with_type: - result = process_func(*_current_el_exprs, current_el_type) + if arg_types is not None: + result = process_func(*_current_el_exprs, arg_types) else: result = process_func(*_current_el_exprs) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d85cd5b3df..9fb7850666 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -69,6 +69,7 @@ IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type] IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type] VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type] +VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type] EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type] CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 3777de7843..29966c30ad 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -220,6 +220,94 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_tuples(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_scalar(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].table + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table] + + np.where(np.expand_dims(mask.asnumpy(), 1), inp.asnumpy()[v2e_table], 1), + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + @pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): @gtx.field_operator From 162a512b2d6882ff16e1a245e2953d7bfe8aa3d3 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 4 Nov 2024 08:53:34 +0100 Subject: [PATCH 024/178] build: update dependencies (#1720) updates minimum gridtools_cpp version for #1699 --- .pre-commit-config.yaml | 24 +++--- constraints.txt | 97 ++++++++++++------------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 97 ++++++++++++------------- src/gt4py/cartesian/cli.py | 2 + src/gt4py/next/ffront/dialect_parser.py | 4 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/utils.py | 2 +- 10 files changed, 117 insertions(+), 117 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e0314bca3..880a422160 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.6.4 + rev: v0.7.2 ##[[[end]]] hooks: # Run the linter. @@ -73,9 +73,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.11.2 ========= + #========= FROM constraints.txt: v1.13.0 ========= ##[[[end]]] - rev: v1.11.2 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.13.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -95,25 +95,25 @@ repos: - attrs==24.2.0 - black==24.8.0 - boltons==24.0.0 - - cached-property==1.5.2 + - cached-property==2.0.1 - click==8.1.7 - - cmake==3.30.3 - - cytoolz==0.12.3 + - cmake==3.30.5 + - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 - factory-boy==3.3.1 - - frozendict==2.4.4 - - gridtools-cpp==2.3.4 + - frozendict==2.4.6 + - gridtools-cpp==2.3.6 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 - - mako==1.3.5 - - nanobind==2.1.0 + - mako==1.3.6 + - nanobind==2.2.0 - ninja==1.11.1.1 - numpy==1.24.4 - packaging==24.1 - - pybind11==2.13.5 - - setuptools==74.1.2 + - pybind11==2.13.6 + - setuptools==75.3.0 - tabulate==0.9.0 - typing-extensions==4.12.2 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index 5df3f58c60..e846d4126c 100644 --- a/constraints.txt +++ b/constraints.txt @@ -14,51 +14,50 @@ babel==2.16.0 # via sphinx backcall==0.2.0 # via ipython black==24.8.0 # via gt4py (pyproject.toml) boltons==24.0.0 # via gt4py (pyproject.toml) -bracex==2.5 # via wcmatch -build==1.2.2 # via pip-tools -bump-my-version==0.26.0 # via -r requirements-dev.in -cached-property==1.5.2 # via gt4py (pyproject.toml) +bracex==2.5.post1 # via wcmatch +build==1.2.2.post1 # via pip-tools +bump-my-version==0.28.0 # via -r requirements-dev.in +cached-property==2.0.1 # via gt4py (pyproject.toml) cachetools==5.5.0 # via tox certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.3.2 # via requests -clang-format==18.1.8 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.4.0 # via requests +clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.3 # via gt4py (pyproject.toml) +cmake==3.30.5 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel contourpy==1.1.1 # via matplotlib coverage==7.6.1 # via -r requirements-dev.in, pytest-cov cycler==0.12.1 # via matplotlib -cytoolz==0.12.3 # via gt4py (pyproject.toml) +cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==0.16.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.5 # via ipykernel +debugpy==1.8.7 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.8 # via dace -distlib==0.3.8 # via virtualenv +dill==0.3.9 # via dace +distlib==0.3.9 # via virtualenv docutils==0.20.1 # via sphinx, sphinx-rtd-theme -eval-type-backport==0.2.0 # via tach exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==28.4.1 # via factory-boy +faker==30.8.2 # via factory-boy fastjsonschema==2.20.0 # via nbformat -filelock==3.16.0 # via tox, virtualenv -fonttools==4.53.1 # via matplotlib +filelock==3.16.1 # via tox, virtualenv +fonttools==4.54.1 # via matplotlib fparser==0.1.4 # via dace -frozendict==2.4.4 # via gt4py (pyproject.toml) +frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.4 # via gt4py (pyproject.toml) -hypothesis==6.112.0 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.0 # via pre-commit -idna==3.8 # via requests +gridtools-cpp==2.3.6 # via gt4py (pyproject.toml) +hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.6.1 # via pre-commit +idna==3.10 # via requests imagesize==1.4.1 # via sphinx importlib-metadata==8.5.0 # via build, jupyter-client, sphinx importlib-resources==6.4.5 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib @@ -70,12 +69,12 @@ jedi==0.19.1 # via ipython jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema -jupyter-client==8.6.2 # via ipykernel, nbclient +jupyter-client==8.6.3 # via ipykernel, nbclient jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat jupytext==1.16.4 # via -r requirements-dev.in kiwisolver==1.4.7 # via matplotlib lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.5 # via gt4py (pyproject.toml) +mako==1.3.6 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via jinja2, mako matplotlib==3.7.5 # via -r requirements-dev.in @@ -83,9 +82,9 @@ matplotlib-inline==0.1.7 # via ipykernel, ipython mdit-py-plugins==0.4.2 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.11.2 # via -r requirements-dev.in +mypy==1.13.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==2.1.0 # via gt4py (pyproject.toml) +nanobind==2.2.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.10.4 # via jupytext, nbclient, nbmake nbmake==1.5.4 # via -r requirements-dev.in @@ -102,25 +101,25 @@ pexpect==4.9.0 # via ipython pickleshare==0.7.5 # via ipython pillow==10.4.0 # via matplotlib pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.23.3 # via -r requirements-dev.in +pipdeptree==2.23.4 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.3.2 # via black, jupyter-core, tox, virtualenv +platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in prompt-toolkit==3.0.36 # via ipython, questionary, tach -psutil==6.0.0 # via -r requirements-dev.in, ipykernel, pytest-xdist +psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pybind11==2.13.5 # via gt4py (pyproject.toml) -pydantic==2.9.1 # via bump-my-version, pydantic-settings, tach -pydantic-core==2.23.3 # via pydantic -pydantic-settings==2.5.2 # via bump-my-version -pydot==2.0.0 # via tach +pybind11==2.13.6 # via gt4py (pyproject.toml) +pydantic==2.9.2 # via bump-my-version, pydantic-settings +pydantic-core==2.23.4 # via pydantic +pydantic-settings==2.6.1 # via bump-my-version +pydot==3.0.2 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via matplotlib, pydot -pyproject-api==1.7.1 # via tox -pyproject-hooks==1.1.0 # via build, pip-tools +pyproject-api==1.8.0 # via tox +pyproject-hooks==1.2.0 # via build, pip-tools pytest==8.3.3 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==5.0.0 # via -r requirements-dev.in @@ -136,10 +135,10 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx -rich==13.8.1 # via bump-my-version, rich-click, tach +rich==13.9.3 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version -rpds-py==0.20.0 # via jsonschema, referencing -ruff==0.6.4 # via -r requirements-dev.in +rpds-py==0.20.1 # via jsonschema, referencing +ruff==0.7.2 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -147,7 +146,7 @@ smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in +sphinx-rtd-theme==3.0.1 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -159,25 +158,25 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.10.7 # via -r requirements-dev.in -tomli==2.0.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tox +tach==0.14.1 # via -r requirements-dev.in +tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version -toolz==0.12.1 # via cytoolz +toolz==1.0.0 # via cytoolz tornado==6.4.1 # via ipykernel, jupyter-client -tox==4.18.1 # via -r requirements-dev.in +tox==4.23.2 # via -r requirements-dev.in traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat types-tabulate==0.9.0.20240106 # via -r requirements-dev.in -typing-extensions==4.12.2 # via annotated-types, black, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm +typing-extensions==4.12.2 # via annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via requests -virtualenv==20.26.4 # via pre-commit, tox -wcmatch==9.0 # via bump-my-version +virtualenv==20.27.1 # via pre-commit, tox +wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -websockets==13.0.1 # via dace +websockets==13.1 # via dace wheel==0.44.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.20.1 # via importlib-metadata, importlib-resources +zipp==3.20.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==24.2 # via pip-tools, pipdeptree -setuptools==74.1.2 # via gt4py (pyproject.toml), pip-tools, setuptools-scm +pip==24.3.1 # via pip-tools, pipdeptree +setuptools==75.3.0 # via gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 10d70397c6..7fea11bc3d 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,7 +67,7 @@ deepdiff==5.6.0 devtools==0.6 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.4 +gridtools-cpp==2.3.6 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 01b21dc1f2..c20883e25e 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,7 +63,7 @@ deepdiff==5.6.0 devtools==0.6 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.4 +gridtools-cpp==2.3.6 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 3c3efab625..64f08e671e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'devtools>=0.6', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.4,==2.*', + 'gridtools-cpp>=2.3.6,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 0b7baec1bc..eb757e0afd 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,51 +14,50 @@ babel==2.16.0 # via -c constraints.txt, sphinx backcall==0.2.0 # via -c constraints.txt, ipython black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -bracex==2.5 # via -c constraints.txt, wcmatch -build==1.2.2 # via -c constraints.txt, pip-tools -bump-my-version==0.26.0 # via -c constraints.txt, -r requirements-dev.in -cached-property==1.5.2 # via -c constraints.txt, gt4py (pyproject.toml) +bracex==2.5.post1 # via -c constraints.txt, wcmatch +build==1.2.2.post1 # via -c constraints.txt, pip-tools +bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in +cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cachetools==5.5.0 # via -c constraints.txt, tox certifi==2024.8.30 # via -c constraints.txt, requests cfgv==3.4.0 # via -c constraints.txt, pre-commit chardet==5.2.0 # via -c constraints.txt, tox -charset-normalizer==3.3.2 # via -c constraints.txt, requests -clang-format==18.1.8 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.4.0 # via -c constraints.txt, requests +clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.3 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel contourpy==1.1.1 # via -c constraints.txt, matplotlib coverage[toml]==7.6.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov cycler==0.12.1 # via -c constraints.txt, matplotlib -cytoolz==0.12.3 # via -c constraints.txt, gt4py (pyproject.toml) +cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.5 # via -c constraints.txt, ipykernel +debugpy==1.8.7 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) -dill==0.3.8 # via -c constraints.txt, dace -distlib==0.3.8 # via -c constraints.txt, virtualenv +dill==0.3.9 # via -c constraints.txt, dace +distlib==0.3.9 # via -c constraints.txt, virtualenv docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme -eval-type-backport==0.2.0 # via -c constraints.txt, tach exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==28.4.1 # via -c constraints.txt, factory-boy +faker==30.8.2 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.0 # via -c constraints.txt, tox, virtualenv -fonttools==4.53.1 # via -c constraints.txt, matplotlib +filelock==3.16.1 # via -c constraints.txt, tox, virtualenv +fonttools==4.54.1 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace -frozendict==2.4.4 # via -c constraints.txt, gt4py (pyproject.toml) +frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.4 # via -c constraints.txt, gt4py (pyproject.toml) -hypothesis==6.112.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.0 # via -c constraints.txt, pre-commit -idna==3.8 # via -c constraints.txt, requests +gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.6.1 # via -c constraints.txt, pre-commit +idna==3.10 # via -c constraints.txt, requests imagesize==1.4.1 # via -c constraints.txt, sphinx importlib-metadata==8.5.0 # via -c constraints.txt, build, jupyter-client, sphinx importlib-resources==6.4.5 ; python_version < "3.9" # via -c constraints.txt, gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib @@ -70,12 +69,12 @@ jedi==0.19.1 # via -c constraints.txt, ipython jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema -jupyter-client==8.6.2 # via -c constraints.txt, ipykernel, nbclient +jupyter-client==8.6.3 # via -c constraints.txt, ipykernel, nbclient jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, nbformat jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in kiwisolver==1.4.7 # via -c constraints.txt, matplotlib lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.5 # via -c constraints.txt, gt4py (pyproject.toml) +mako==1.3.6 # via -c constraints.txt, gt4py (pyproject.toml) markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in @@ -83,9 +82,9 @@ matplotlib-inline==0.1.7 # via -c constraints.txt, ipykernel, ipython mdit-py-plugins==0.4.2 # via -c constraints.txt, jupytext mdurl==0.1.2 # via -c constraints.txt, markdown-it-py mpmath==1.3.0 # via -c constraints.txt, sympy -mypy==1.11.2 # via -c constraints.txt, -r requirements-dev.in +mypy==1.13.0 # via -c constraints.txt, -r requirements-dev.in mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.1.0 # via -c constraints.txt, gt4py (pyproject.toml) +nanobind==2.2.0 # via -c constraints.txt, gt4py (pyproject.toml) nbclient==0.6.8 # via -c constraints.txt, nbmake nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in @@ -102,25 +101,25 @@ pexpect==4.9.0 # via -c constraints.txt, ipython pickleshare==0.7.5 # via -c constraints.txt, ipython pillow==10.4.0 # via -c constraints.txt, matplotlib pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.23.3 # via -c constraints.txt, -r requirements-dev.in +pipdeptree==2.23.4 # via -c constraints.txt, -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema -platformdirs==4.3.2 # via -c constraints.txt, black, jupyter-core, tox, virtualenv +platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via -c constraints.txt, pytest, tox ply==3.11 # via -c constraints.txt, dace pre-commit==3.5.0 # via -c constraints.txt, -r requirements-dev.in prompt-toolkit==3.0.36 # via -c constraints.txt, ipython, questionary, tach -psutil==6.0.0 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist +psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data -pybind11==2.13.5 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.1 # via -c constraints.txt, bump-my-version, pydantic-settings, tach -pydantic-core==2.23.3 # via -c constraints.txt, pydantic -pydantic-settings==2.5.2 # via -c constraints.txt, bump-my-version -pydot==2.0.0 # via -c constraints.txt, tach +pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) +pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.23.4 # via -c constraints.txt, pydantic +pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version +pydot==3.0.2 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot -pyproject-api==1.7.1 # via -c constraints.txt, tox -pyproject-hooks==1.1.0 # via -c constraints.txt, build, pip-tools +pyproject-api==1.8.0 # via -c constraints.txt, tox +pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools pytest==8.3.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in @@ -136,17 +135,17 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.8.1 # via -c constraints.txt, bump-my-version, rich-click, tach +rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version -rpds-py==0.20.0 # via -c constraints.txt, jsonschema, referencing -ruff==0.6.4 # via -c constraints.txt, -r requirements-dev.in +rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing +ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==2.0.0 # via -c constraints.txt, -r requirements-dev.in +sphinx-rtd-theme==3.0.1 # via -c constraints.txt, -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx @@ -158,25 +157,25 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.10.7 # via -c constraints.txt, -r requirements-dev.in -tomli==2.0.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tox +tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in +tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version -toolz==0.12.1 # via -c constraints.txt, cytoolz +toolz==1.0.0 # via -c constraints.txt, cytoolz tornado==6.4.1 # via -c constraints.txt, ipykernel, jupyter-client -tox==4.18.1 # via -c constraints.txt, -r requirements-dev.in +tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat types-tabulate==0.9.0.20240106 # via -c constraints.txt, -r requirements-dev.in -typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm +typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.26.4 # via -c constraints.txt, pre-commit, tox -wcmatch==9.0 # via -c constraints.txt, bump-my-version +virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox +wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -websockets==13.0.1 # via -c constraints.txt, dace +websockets==13.1 # via -c constraints.txt, dace wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -zipp==3.20.1 # via -c constraints.txt, importlib-metadata, importlib-resources +zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==24.2 # via -c constraints.txt, pip-tools, pipdeptree -setuptools==74.1.2 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm +pip==24.3.1 # via -c constraints.txt, pip-tools, pipdeptree +setuptools==75.3.0 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 23f8791ca7..91daed9e98 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -138,6 +138,8 @@ def convert( self, value: str, param: Optional[click.Parameter], ctx: Optional[click.Context] ) -> Tuple[str, Any]: backend = ctx.params["backend"] if ctx else gt4pyc.backend.from_name("numpy") + assert isinstance(backend, type) + assert issubclass(backend, gt4pyc.backend.Backend) name, value = self._try_split(value) if name.strip() not in backend.options: self.fail(f"Backend {backend.name} received unknown option: {name}!") diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 23b719abb7..79d188cdf2 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -106,11 +106,11 @@ def get_location(self, node: ast.AST) -> SourceLocation: # `FixMissingLocations` ensures that all nodes have the location attributes assert hasattr(node, "lineno") - line = node.lineno + line_offset if node.lineno is not None else None + line = node.lineno + line_offset assert hasattr(node, "end_lineno") end_line = node.end_lineno + line_offset if node.end_lineno is not None else None assert hasattr(node, "col_offset") - column = 1 + node.col_offset + col_offset if node.col_offset is not None else None + column = 1 + node.col_offset + col_offset assert hasattr(node, "end_col_offset") end_column = ( 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 3b711212a3..d932431b51 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -245,7 +245,7 @@ def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs value ) # default implementation for scalars, Fields are handled via dispatch - return _math_builtin(value) + return cast(common.Field | core_defs.ScalarT, _math_builtin(value)) # type: ignore[operator] # calling a function of unknown type impl.__name__ = name globals()[name] = BuiltInFunction(impl) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 7489908ba9..f1a82c6bd9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -140,7 +140,7 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert result_collection_constructor is not None return result_collection_constructor(impl(*arg) for arg in zip(*args)) - return fun( # type: ignore[misc] # mypy not smart enough + return fun( # type: ignore[call-arg, misc] # mypy not smart enough *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` From 78d8e922053e4f6259ef6af9d18703f284522c96 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 4 Nov 2024 10:58:28 +0100 Subject: [PATCH 025/178] feat[next]: enable gtir.embedded (and add support for Lists in output) (#1703) Lists returned from an `as_fieldop`ed stencil will be turned into a local dimension of the the resulting field. In case of a `make_const_list`, a magic local dimension `_CONST_DIM` is used. This is a hack, but localized to `itir.embedded`. A clean implementation will probably involve to tag the `make_const_list` with the neighborhood it is meant to be used with. --- src/gt4py/next/ffront/foast_to_gtir.py | 30 +++- src/gt4py/next/iterator/embedded.py | 168 ++++++++++++++---- .../next/iterator/transforms/pass_manager.py | 7 +- .../type_system/type_specifications.py | 5 +- tests/next_tests/definitions.py | 4 + .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_decorator.py | 6 +- .../test_embedded_field_with_list.py | 124 +++++++++++++ 8 files changed, 294 insertions(+), 52 deletions(-) create mode 100644 tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 10583b90ff..6cf4cc67fd 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -116,7 +116,31 @@ def visit_FieldOperator( def visit_ScanOperator( self, node: foast.ScanOperator, **kwargs: Any ) -> itir.FunctionDefinition: - raise NotImplementedError("TODO") + # note: we don't need the axis here as this is handled by the program + # decorator + assert isinstance(node.type, ts_ffront.ScanOperatorType) + + # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. + # In iterator IR we didn't properly specify if this is legal, + # however after lift-inlining the expressions are transformed back to literals. + forward = self.visit(node.forward, **kwargs) + init = self.visit(node.init, **kwargs) + + # lower definition function + func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + new_body = func_definition.expr + + stencil_args: list[itir.Expr] = [] + assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + for param in func_definition.params[1:]: + new_body = im.let(param.id, im.deref(param.id))(new_body) + stencil_args.append(im.ref(param.id)) + + definition = itir.Lambda(params=func_definition.params, expr=new_body) + + body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args) + + return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -324,10 +348,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: *lowered_args, *lowered_kwargs.values() ) - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - raise NotImplementedError("TODO") - return result raise AssertionError( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index afe0cec402..84dd9e3f72 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,6 +54,7 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime +from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -186,6 +187,12 @@ def mapped_index( NamedFieldIndices: TypeAlias = Mapping[Tag, FieldIndex | SparsePositionEntry] +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. +_CONST_DIM = common.Dimension(value="_CONST_DIM", kind=common.DimensionKind.LOCAL) + + @runtime_checkable class ItIterator(Protocol): """ @@ -227,6 +234,12 @@ class MutableLocatedField(LocatedField, Protocol): def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... +def _numpy_structured_value_to_tuples(value: Any) -> Any: + if _elem_dtype(value).names is not None: + return tuple(_numpy_structured_value_to_tuples(v) for v in value) + return value + + class Column(np.lib.mixins.NDArrayOperatorsMixin): """Represents a column when executed in column mode (`column_axis != None`). @@ -247,6 +260,10 @@ def dtype(self) -> np.dtype: # not directly dtype of `self.data` as that might be a structured type containing `None` return _elem_dtype(self.data[self.kstart]) + def __gt_type__(self) -> ts.TypeSpec: + elem = self.data[self.kstart] + return type_translation.from_value(_numpy_structured_value_to_tuples(elem)) + def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] # numpy type @@ -576,17 +593,20 @@ def execute_shift( for i, p in reversed(list(enumerate(new_entry))): # first shift applies to the last sparse dimensions of that axis type if p is None: - offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] - assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ - None, - common._DEFAULT_SKIP_VALUE, - ]: - return None - - new_entry[i] = index + if tag == _CONST_DIM.value: + new_entry[i] = 0 + else: + offset_implementation = offset_provider[tag] + assert isinstance(offset_implementation, common.Connectivity) + cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_int_index(cur_index) + if offset_implementation.mapped_index(cur_index, index) in [ + None, + common._DEFAULT_SKIP_VALUE, + ]: + return None + + new_entry[i] = index break # the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard return cast(IncompletePosition, pos) | {tag: new_entry} @@ -920,9 +940,9 @@ def deref(self) -> Any: return _make_tuple(self.field, position, column_axis=self.column_axis) -def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[Tag]: +def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[common.Dimension]: return [ - axis.value + axis for axis in axes if isinstance(axis, common.Dimension) and axis.kind == common.DimensionKind.LOCAL ] @@ -945,7 +965,7 @@ def make_in_iterator( new_pos: Position = pos.copy() for sparse_dim in set(sparse_dimensions): init = [None] * sparse_dimensions.count(sparse_dim) - new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused + new_pos[sparse_dim.value] = init # type: ignore[assignment] # looks like mypy is confused if column_dimension is not None: column_range = embedded_context.closure_column_range.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted @@ -956,7 +976,7 @@ def make_in_iterator( ) if len(sparse_dimensions) >= 1: if len(sparse_dimensions) == 1: - return SparseListIterator(it, sparse_dimensions[0]) + return SparseListIterator(it, sparse_dimensions[0].value) else: raise NotImplementedError( f"More than one local dimension is currently not supported, got {sparse_dimensions}." @@ -1004,7 +1024,17 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if isinstance(self._ndarrayfield, common.MutableField): - self._ndarrayfield[self._translate_named_indices(named_indices)] = value + if isinstance(value, _List): + for i, v in enumerate(value): # type:ignore[var-annotated, arg-type] + self._ndarrayfield[ + self._translate_named_indices({**named_indices, value.offset.value: i}) # type: ignore[dict-item] + ] = v + elif isinstance(value, _ConstList): + self._ndarrayfield[ + self._translate_named_indices({**named_indices, _CONST_DIM.value: 0}) + ] = value.value + else: + self._ndarrayfield[self._translate_named_indices(named_indices)] = value else: raise RuntimeError("Assigment into a non-mutable Field is not allowed.") @@ -1383,7 +1413,23 @@ def impl(it: ItIterator) -> ItIterator: DT = TypeVar("DT") -class _List(tuple, Generic[DT]): ... +@dataclasses.dataclass(frozen=True) +class _List(Generic[DT]): + values: tuple[DT, ...] + offset: runtime.Offset + + def __getitem__(self, i: int): + return self.values[i] + + def __gt_type__(self) -> itir_ts.ListType: + offset_tag = self.offset.value + assert isinstance(offset_tag, str) + element_type = type_translation.from_value(self.values[0]) + assert isinstance(element_type, ts.DataType) + return itir_ts.ListType( + element_type=element_type, + offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), + ) @dataclasses.dataclass(frozen=True) @@ -1393,6 +1439,14 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value + def __gt_type__(self) -> itir_ts.ListType: + element_type = type_translation.from_value(self.value) + assert isinstance(element_type, ts.DataType) + return itir_ts.ListType( + element_type=element_type, + offset_type=_CONST_DIM, + ) + @builtins.neighbors.register(EMBEDDED) def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: @@ -1403,9 +1457,12 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: connectivity = offset_provider[offset_str] assert isinstance(connectivity, common.Connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := it.shift(offset_str, i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.max_neighbors) + if (shifted := it.shift(offset_str, i)).can_deref() + ), + offset=offset, ) @@ -1414,10 +1471,23 @@ def list_get(i, lst: _List[Optional[DT]]) -> Optional[DT]: return lst[i] +def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: + offsets = set((lst.offset for lst in lists if hasattr(lst, "offset"))) + if len(offsets) == 0: + return None + if len(offsets) == 1: + return offsets.pop() + raise AssertionError("All lists must have the same offset.") + + @builtins.map_.register(EMBEDDED) def map_(op): def impl_(*lists): - return _List(map(lambda x: op(*x), zip(*lists))) + offset = _get_offset(*lists) + if offset is None: + return _ConstList(value=op(*[lst.value for lst in lists])) + else: + return _List(values=tuple(map(lambda x: op(*x), zip(*lists))), offset=offset) return impl_ @@ -1438,7 +1508,7 @@ def sten(*lists): break # we can check a single argument for length, # because all arguments share the same pattern - n = len(lst) + n = len(lst.values) res = init for i in range(n): res = fun(res, *(lst[i] for lst in lists)) @@ -1454,14 +1524,23 @@ class SparseListIterator: offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) def deref(self) -> Any: + if self.list_offset == _CONST_DIM.value: + return _ConstList( + value=self.it.shift(*self.offsets, SparseTag(self.list_offset), 0).deref() + ) offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] assert isinstance(connectivity, common.Connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.max_neighbors) + if ( + shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) + ).can_deref() + ), + offset=runtime.Offset(value=self.list_offset), ) def can_deref(self) -> bool: @@ -1654,16 +1733,6 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType: return eve.NOTHING -def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType: - if structured_dtype.names is None: - return type_translation.from_dtype(core_defs.dtype(structured_dtype)) - return ts.TupleType( - types=[ - _structured_dtype_to_typespec(structured_dtype[name]) for name in structured_dtype.names - ] - ) - - def _get_output_type( fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, @@ -1682,8 +1751,29 @@ def _get_output_type( with embedded_context.new_context(closure_column_range=col_range) as ctx: single_pos_result = ctx.run(_compute_at_position, fun, args, pos_in_domain, col_dim) assert single_pos_result is not _UNDEFINED, "Stencil contains an Out-Of-Bound access." - dtype = _elem_dtype(single_pos_result) - return _structured_dtype_to_typespec(dtype) + return type_translation.from_value(single_pos_result) + + +def _fieldspec_list_to_value( + domain: common.Domain, type_: ts.TypeSpec +) -> tuple[common.Domain, ts.TypeSpec]: + """Translate the list element type into the domain.""" + if isinstance(type_, itir_ts.ListType): + if type_.offset_type == _CONST_DIM: + return domain.insert( + len(domain), common.named_range((_CONST_DIM, 1)) + ), type_.element_type + else: + offset_provider = embedded_context.offset_provider.get() + offset_type = type_.offset_type + assert isinstance(offset_type, common.Dimension) + connectivity = offset_provider[offset_type.value] + assert isinstance(connectivity, common.Connectivity) + return domain.insert( + len(domain), + common.named_range((offset_type, connectivity.max_neighbors)), + ), type_.element_type + return domain, type_ @builtins.as_fieldop.register(EMBEDDED) @@ -1691,7 +1781,9 @@ def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.Unstruct def impl(*args): xp = field_utils.get_array_ns(*args) type_ = _get_output_type(fun, domain, [promote_scalars(arg) for arg in args]) - out = field_utils.field_from_typespec(type_, common.domain(domain), xp) + + new_domain, type_ = _fieldspec_list_to_value(common.domain(domain), type_) + out = field_utils.field_from_typespec(type_, new_domain, xp) # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 7c35d552dc..0c08bf2b9d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -223,10 +223,7 @@ def apply_fieldview_transforms( ) -> itir.Program: ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = InlineLambdas.apply(ir, opcount_preserving=True) - ir = infer_domain.infer_program( - ir, - offset_provider=offset_provider, - ) + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 94a174dca4..edb56f5659 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Literal +from typing import Literal, Optional from gt4py.next import common from gt4py.next.type_system import type_specifications as ts @@ -31,6 +31,9 @@ class OffsetLiteralType(ts.TypeSpec): @dataclasses.dataclass(frozen=True) class ListType(ts.DataType): element_type: ts.DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None @dataclasses.dataclass(frozen=True) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 123384a098..2c4102d5af 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -187,6 +187,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.GTIR_EMBEDDED: [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a0e72ede8d..333a2dae28 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -46,7 +46,7 @@ def __gt_allocator__( @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available + next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index e3e919e52e..f26424bf0e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,8 +30,10 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.FencilDefinition) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.FencilDefinition) + assert isinstance(testee.itir, (itir.FencilDefinition, itir.Program)) + assert isinstance( + testee.with_backend(cartesian_case.backend).itir, (itir.FencilDefinition, itir.Program) + ) def test_frozen(cartesian_case): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py new file mode 100644 index 0000000000..56d52c75ae --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -0,0 +1,124 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next.embedded import context as embedded_context +from gt4py.next.iterator import embedded, runtime +from gt4py.next.iterator.builtins import ( + as_fieldop, + deref, + if_, + make_const_list, + map_, + neighbors, + plus, +) + + +E = gtx.Dimension("E") +V = gtx.Dimension("V") +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) +E2V = gtx.FieldOffset("E2V", source=V, target=(E, E2VDim)) + + +# 0 --0-- 1 --1-- 2 +e2v_arr = np.array([[0, 1], [1, 2]]) +e2v_conn = gtx.NeighborTableOffsetProvider( + table=e2v_arr, + origin_axis=E, + neighbor_axis=V, + max_neighbors=2, + has_skip_values=False, +) + + +def test_write_neighbors(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda it: neighbors(E2V, it), domain)(inp) + + inp = gtx.as_field([V], np.arange(3)) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda: make_const_list(42.0), domain)() + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[42.0], [42.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_neighbors_and_const_list(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + inp = gtx.as_field([V], np.arange(3)) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_conditional_neighbors_and_const_list(): + def testee(inp, mask): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda m, x, y: map_(if_)(deref(m), deref(x), deref(y)), domain)( + as_fieldop(lambda it: make_const_list(deref(it)), domain)(mask), + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), + ) + + inp = gtx.as_field([V], np.arange(3)) + mask_field = gtx.as_field([E], np.array([True, False])) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp, mask_field) + + ref = np.empty_like(e2v_arr, dtype=float) + ref[0, :] = e2v_arr[0, :] + ref[1, :] = 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_const_list_and_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda: make_const_list(1.0), domain)(), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[43.0], [43.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref) From 827d40cf43b0d00cdd61c3bac15f26b063ae45bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 4 Nov 2024 11:12:57 +0100 Subject: [PATCH 026/178] test[cartesian]: Unskip blocked DaCe tests after DaCe upgrade (#1714) A couple of tests were skipped in the DaCe backend with issue https://github.com/GridTools/gt4py/issues/1084 because they were throwing compiler errors. Current versions of DaCe can handle these cases again. Re-enabling the test cases. --- .../test_code_generation.py | 6 ----- .../multi_feature_tests/test_suites.py | 23 ++----------------- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 398e312af3..c4d07d7337 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -366,9 +366,6 @@ def stencil( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil_ij( in_field: gtscript.Field[np.float_], @@ -391,9 +388,6 @@ def stencil_ijk( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets_and_while_loop(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil( pe1: gtscript.Field[np.float_], diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 44112f3899..d3a5744389 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -738,29 +738,10 @@ def validation(field_in, field_out, *, domain, origin): field_out[:, :, 0] = field_in[:, :, 0] -def _skip_dace_cpu_gcc_error(backends): - paramtype = type(pytest.param()) - res = [] - for b in backends: - if isinstance(b, paramtype) and b.values[0] == "dace:cpu": - res.append( - pytest.param( - *b.values, - marks=[ - *b.marks, - pytest.mark.skip("Internal compiler error in GitHub action container"), - ], - ) - ) - else: - res.append(b) - return res - - class TestVariableKRead(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float32, "field_out": np.float32, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] @@ -782,7 +763,7 @@ def validation(field_in, field_out, index, *, domain, origin): class TestVariableKAndReadOutside(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float64, "field_out": np.float64, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(0.1, 10), axes="IJK", boundary=[(0, 0), (0, 0), (1, 0)] From 44d6224e2c9f7d0b7a5dadf12d15957c2bb9560d Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Nov 2024 12:02:29 +0100 Subject: [PATCH 027/178] feat[next][dace]: Add helper method to replace unicode symbols (#1696) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DaCe requires C-compatible strings for the names of data containers, such as arrays and scalars. GT4Py uses a unicode symbols (`ᐞ`) as name separator in the SSA pass, which generates invalid symbols for DaCe. This PR introduces a helper method to find new names for invalid symbols present in the IR. --- .../runners/dace_fieldview/gtir_sdfg.py | 8 ++++ .../runners/dace_fieldview/utility.py | 43 ++++++++++++++++++- .../dace_tests/test_gtir_to_sdfg.py | 20 ++++----- 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index e489f130db..48c666a363 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -399,6 +399,14 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + + # DaCe requires C-compatible strings for the names of data containers, + # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name + # separator in the SSA pass, which generates invalid symbols for DaCe. + # Here we find new names for invalid symbols present in the IR. + node = dace_gtir_utils.replace_invalid_symbols(sdfg, node) + + # start block of the stateful graph entry_state = sdfg.add_state("program_entry", is_start_block=True) # declarations of temporaries result in transient array definitions in the SDFG diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 355eaac903..baae8a6ccd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,9 +9,13 @@ from __future__ import annotations import itertools -from typing import Any +from typing import Any, Dict, TypeVar +import dace + +from gt4py import eve from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.type_system import type_specifications as ts @@ -66,3 +70,40 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: return ts.TupleType( types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data] ) + + +def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: + """ + Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). + + If any invalid symbol present, this funtion returns a copy of the input IR where + the invalid symbols have been replaced with new names. If all symbols are valid, + the input IR is returned without copying it. + """ + + class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator): + T = TypeVar("T", gtir.Sym, gtir.SymRef) + + def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T: + sym = str(node.id) + return type(node)(id=symtable.get(sym, sym), type=node.type) + + def visit_Sym(self, node: gtir.Sym, *, symtable: Dict[str, str]) -> gtir.Sym: + return self._replace_sym(node, symtable) + + def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.SymRef: + return self._replace_sym(node, symtable) + + # program arguments are checked separetely, because they cannot be replaced + if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params): + raise ValueError("Invalid symbol in program parameters.") + + invalid_symbols_mapping = { + sym_id: sdfg.temp_data_name() + for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set() + if not dace.dtypes.validate_name(sym_id := str(sym.id)) + } + if len(invalid_symbols_mapping) != 0: + return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) + else: + return ir diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 41f540d3cf..9f5498b4a7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1629,20 +1629,20 @@ def test_gtir_let_lambda(): declarations=[], body=[ gtir.SetAt( - # `x1` is a let-lambda expression representing `x * 3` - # `x2` is a let-lambda expression representing `x * 4` - # - note that the let-symbol `x2` is used twice, in a nested let-expression, to test aliasing of the symbol - # `x3` is a let-lambda expression simply accessing `x` field symref - expr=im.let("x1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( + # `xᐞ1` is a let-lambda expression representing `x * 3` + # `xᐞ2` is a let-lambda expression representing `x * 4` + # - note that the let-symbol `xᐞ2` is used twice, in a nested let-expression, to test aliasing of the symbol + # `xᐞ3` is a let-lambda expression simply accessing `x` field symref + expr=im.let("xᐞ1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( im.let( - "x2", - im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.op_as_fieldop("plus", subdomain)("x2", "x2") + "xᐞ2", + im.let("xᐞ2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( + im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ2") ), )( - im.let("x3", "x")( + im.let("xᐞ3", "x")( im.op_as_fieldop("plus", subdomain)( - "x1", im.op_as_fieldop("plus", subdomain)("x2", "x3") + "xᐞ1", im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ3") ) ) ) From 9f3b0a7508ce5a2e66bae5c528f4a1b4d4194728 Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 4 Nov 2024 13:00:59 +0100 Subject: [PATCH 028/178] feat[next]: Index builtin (#1699) Adds index builtin for embedded and gtfn backends. --- src/gt4py/next/iterator/builtins.py | 6 ++ src/gt4py/next/iterator/embedded.py | 9 ++- src/gt4py/next/iterator/ir.py | 8 ++- src/gt4py/next/iterator/pretty_parser.py | 8 +-- .../iterator/type_system/type_synthesizer.py | 8 +++ .../codegens/gtfn/codegen.py | 2 +- .../codegens/gtfn/gtfn_ir.py | 3 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 1 - .../runners/dace_iterator/__init__.py | 13 ++++- .../runners/dace_iterator/itir_to_sdfg.py | 15 ++++- tests/next_tests/definitions.py | 4 +- .../iterator_tests/test_program.py | 56 ++++++++++++++++++- 12 files changed, 117 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 264ac2685c..c8edc12331 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -22,6 +22,11 @@ def as_fieldop(*args): raise BackendNotSelectedError() +@builtin_dispatch +def index(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def deref(*args): raise BackendNotSelectedError() @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "unstructured_domain", "named_range", "as_fieldop", + "index", *MATH_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 84dd9e3f72..6221c95522 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1204,12 +1204,14 @@ def premap( def restrict(self, item: common.AnyIndexSpec) -> Self: if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item): + assert len(item) == 1 assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) + # TODO(tehrengruber): Use a regular zero dimensional field instead. return self.__class__(self._dimension, r) - # TODO set a domain... + # TODO: set a domain... raise NotImplementedError() __call__ = premap @@ -1793,6 +1795,11 @@ def impl(*args): return impl +@builtins.index.register(EMBEDDED) +def index(axis: common.Dimension) -> common.Field: + return IndexField(axis) + + @runtime.closure.register(EMBEDDED) def closure( domain_: Domain, diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 42da4c83a6..b6f543e9d1 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -101,13 +101,18 @@ class StencilClosure(Node): domain: FunCall stencil: Expr output: Union[SymRef, FunCall] - inputs: List[SymRef] + inputs: List[Union[SymRef, FunCall]] @datamodels.validator("output") def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): raise ValueError("Only FunCall to 'make_tuple' allowed.") + @datamodels.validator("inputs") + def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): + if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): + raise ValueError("Only FunCall to 'index' allowed.") + UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} @@ -183,6 +188,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", + "index", # `index(dim)` creates a dim-field that has the current index at each point *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 08459a9423..b4a673772f 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -31,9 +31,9 @@ INT_LITERAL: SIGNED_INT FLOAT_LITERAL: SIGNED_FLOAT OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" - _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL + AXIS_LITERAL: CNAME ("ᵥ" | "ₕ") + _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL ID_NAME: CNAME - AXIS_NAME: CNAME ("ᵥ" | "ₕ") ?prec0: prec1 | "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: def ID_NAME(self, value: lark_lexer.Token) -> str: return value.value - def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral: + def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral: name = value.value[:-1] kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL return ir.AxisLiteral(value=name, kind=kind) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c55cfd8d51..6579107197 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -189,6 +189,14 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType: return ts.TupleType(types=list(args)) +@_register_builtin_type_synthesizer +def index(arg: ts.DimensionType) -> ts.FieldType: + return ts.FieldType( + dims=[arg.dim], + dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), + ) + + @_register_builtin_type_synthesizer def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: assert ( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 92dbcedeaa..bfc45d7944 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -260,7 +260,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include #include - + namespace generated{ namespace gtfn = ::gridtools::fn; diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1995e4de0b..20a1a0cf76 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -153,7 +153,7 @@ class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] - inputs: list[Union[SymRef, SidComposite, SidFromScalar]] + inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] class Scan(Node): @@ -192,6 +192,7 @@ class TemporaryAllocation(Node): "unstructured_domain", "named_range", "reduce", + "index", ] ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS TYPEBUILTINS = itir.TYPEBUILTINS diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..fb2645208c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E tuple_constructor=lambda *elements: SidComposite(values=list(elements)), ) - assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef)) lowered_inputs.append(lowered_input_as_sid) backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index dab8d29fd1..6383d4bb44 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,6 +24,7 @@ from gt4py.next import common from gt4py.next.ffront import decorator from gt4py.next.iterator import transforms as itir_transforms +from gt4py.next.iterator.ir import SymRef from gt4py.next.iterator.transforms import program_to_fencil from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -197,11 +198,16 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field # Add them as dynamic properties to the SDFG + assert all( + isinstance(in_field, SymRef) + for closure in self.itir.closures + for in_field in closure.inputs + ) # backend only supports SymRef inputs, not `index` calls input_fields = [ - str(in_field.id) + str(in_field.id) # type: ignore[union-attr] # ensured by assert for closure in self.itir.closures for in_field in closure.inputs - if str(in_field.id) in fields + if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert ] sdfg.gt4py_program_input_fields = { in_field: dim @@ -237,6 +243,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: closure.stencil, num_args=len(closure.inputs) ) for param, shifts in zip(closure.inputs, params_shifts): + assert isinstance( + param, SymRef + ) # backend only supports SymRef inputs, not `index` calls if not isinstance(param.id, str): continue if param.id not in sdfg.gt4py_program_input_fields: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index d52fbc5857..a824760ce4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -357,7 +357,10 @@ def visit_StencilClosure( closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert neighbor_tables = get_used_connectivities(node, self.offset_provider) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -565,7 +568,10 @@ def _visit_scan_stencil_closure( assert isinstance(node.output, SymRef) neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -732,7 +738,10 @@ def _visit_parallel_stencil_closure( ], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 2c4102d5af..3fef43865b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" +USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -127,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), @@ -145,6 +145,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), + (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ (ALL, SKIP, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 4eab7502e7..db1c2a42aa 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -10,13 +10,22 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.builtins import ( + as_fieldop, + cartesian_domain, + deref, + index, + named_range, + shift, +) from gt4py.next.iterator.runtime import fendef, fundef, set_at from next_tests.unit_tests.conftest import program_processor, run_processor I = gtx.Dimension("I") +Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) @fundef @@ -44,3 +53,48 @@ def test_prog(program_processor): run_processor(copy_program, program_processor, inp, out, isize, offset_provider={}) if validate: assert np.allclose(inp.asnumpy(), out.asnumpy()) + + +@fendef +def index_program_simple(out, size): + set_at( + as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.starts_from_gtir_program +@pytest.mark.uses_index_builtin +def test_index_builtin(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_simple, program_processor, out, isize, offset_provider={}) + if validate: + assert np.allclose(np.arange(10), out.asnumpy()) + + +@fendef +def index_program_shift(out, size): + set_at( + as_fieldop( + lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size)) + )(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.uses_index_builtin +def test_index_builtin_shift(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I}) + if validate: + assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy()) From d3bd61d026d2e70f8ae8fcfb8fa1b85243deb294 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 4 Nov 2024 13:05:44 +0100 Subject: [PATCH 029/178] test[next]: non-supported itir.List test-case (#1721) added for documentation purposes --- .../test_embedded_field_with_list.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index 56d52c75ae..dcc3a306f2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -105,6 +105,34 @@ def testee(inp, mask): np.testing.assert_array_equal(result.asnumpy(), ref) +def test_write_non_mapped_conditional_neighbors_and_const_list(): + """ + This test-case demonstrates a non-supported pattern: + Current ITIR requires the `if_` to be `map_`ed, see `test_write_map_conditional_neighbors_and_const_list`. + We keep it here for documenting corner cases of the `itir.List` implementation for future discussions. + """ + + pytest.skip("Unsupported.") + + def testee(inp, mask): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda m, x, y: if_(deref(m), deref(x), deref(y)), domain)( + mask, + as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + ) + + inp = gtx.as_field([V], np.arange(3)) + mask_field = gtx.as_field([E], np.array([True, False])) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp, mask_field) + + ref = np.empty_like(e2v_arr, dtype=float) + ref[0, :] = e2v_arr[0, :] + ref[1, :] = 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + def test_write_map_const_list_and_const_list(): def testee(): domain = runtime.UnstructuredDomain({E: range(2)}) From d4b1d7567b1286959b83c7914701f272c1460bc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:41:10 +0100 Subject: [PATCH 030/178] fix[DaCe]: Disable Some Transformations (#1711) DaCe's `MapReduceFusion` and `MapWCRFusion` are interesting as they move the initialization of the reduction accumulator away, which enables more fusion. However, they currently have a bug, as they assume that the reduction node is in the global scope and not inside a map scope. --- .../runners/dace_fieldview/transformations/auto_opt.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 37cc89aa2b..e070cdfe4e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -257,9 +257,12 @@ def gt_auto_optimize( sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, - # TODO(phimuell): Investigate if these two are appropriate. - dace_dataflow.MapReduceFusion, - dace_dataflow.MapWCRFusion, + # TODO(phimuell): The transformation are interesting, but they have + # a bug as they assume that they are not working inside a map scope. + # Before we use them we have to fix them. + # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc + # dace_dataflow.MapReduceFusion, + # dace_dataflow.MapWCRFusion, ], validate=validate, validate_all=validate_all, From 725b6ba070f0f6eadda250aa308803a7ff30b685 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 4 Nov 2024 13:46:36 +0100 Subject: [PATCH 031/178] build: update gitpod image (#1722) --- .gitpod.Dockerfile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitpod.Dockerfile b/.gitpod.Dockerfile index 967ae36f2e..5d02a0f436 100644 --- a/.gitpod.Dockerfile +++ b/.gitpod.Dockerfile @@ -1,8 +1,6 @@ -FROM gitpod/workspace-python +FROM gitpod/workspace-python-3.11 USER root RUN apt-get update \ && apt-get install -y libboost-dev \ && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* USER gitpod -RUN pyenv install 3.10.2 -RUN pyenv global 3.10.2 From eea1fb63717beda13516137d9afb17d0e79d396a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 4 Nov 2024 15:11:27 +0100 Subject: [PATCH 032/178] bug[next]: fix missing local kind in gtfn connectivity (#1715) The second dimension of a connectivity is a local dimension. Before we defaulted to make this dimension horizontal. Currently, this information is not used. --- .../codegens/gtfn/gtfn_module.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d729a5ba2f..07eec0b64b 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -18,7 +18,6 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen from gt4py.next import common -from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager @@ -84,7 +83,7 @@ def _process_regular_arguments( self, program: itir.FencilDefinition | itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider: common.OffsetProvider, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -107,20 +106,20 @@ def _process_regular_arguments( # translate sparse dimensions to tuple dtype dim_name = dim.value connectivity = offset_provider[dim_name] - assert isinstance(connectivity, Connectivity) + assert isinstance(connectivity, common.Connectivity) size = connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, Connectivity | Dimension] + self, offset_provider: dict[str, common.Connectivity | common.Dimension] ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] for name, connectivity in offset_provider.items(): - if isinstance(connectivity, Connectivity): + if isinstance(connectivity, common.Connectivity): if connectivity.index_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." @@ -131,7 +130,12 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[connectivity.origin_axis, Dimension(name)], + dims=[ + connectivity.origin_axis, + common.Dimension( + name, kind=common.DimensionKind.LOCAL + ), # TODO(havogt): we should not use the name of the offset as the name of the local dimension + ], dtype=ts.ScalarType( type_translation.get_scalar_kind(connectivity.index_type) ), @@ -149,7 +153,7 @@ def _process_connectivity_args( arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, Dimension): + elif isinstance(connectivity, common.Dimension): pass else: raise AssertionError( @@ -162,7 +166,7 @@ def _process_connectivity_args( def _preprocess_program( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, Connectivity | Dimension], + offset_provider: dict[str, common.Connectivity | common.Dimension], ) -> itir.Program: if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: return fencil_to_program.FencilToProgram().apply( @@ -196,7 +200,7 @@ def _preprocess_program( def generate_stencil_source( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, Connectivity | Dimension], + offset_provider: dict[str, common.Connectivity | common.Dimension], column_axis: Optional[common.Dimension], ) -> str: new_program = self._preprocess_program(program, offset_provider) From 604e377273a6b187aa583818181c7e681056b428 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 5 Nov 2024 12:26:34 +0100 Subject: [PATCH 033/178] feat[next]: Enable GTIR dace backend in feature tests (#1705) Equivalent of #1702, but for dace backend. These code changes are tested on the GTIR integration branch. Note that SDFG auto-optimization is disabled by default, in both the CPU and GPU DaCe backends. We have to keep the GPU backend disabled in GT4Py feature tests because there are errors related to symbolic domain and sub-domain computation in following tests: - test_execution.py: `test_domain_input_bounds_1, test_domain_tuple` - test_where.py: `test_where_k_offset` - test_laplacian.py: `test_ffront_lap, test_ffront_skewedlap, test_ffront_laplap` These errors are solved by next DaCe release (v1.0.0). The GPU tests can be enabled once PR #1639 is merged. --- .../next/program_processors/runners/dace.py | 40 +++++++++----- .../runners/dace_common/dace_backend.py | 30 ++++++++--- .../runners/dace_common/workflow.py | 4 +- .../runners/dace_fieldview/gtir_sdfg.py | 21 ++++++-- .../runners/dace_fieldview/workflow.py | 30 ++++++++--- tests/next_tests/definitions.py | 13 ++++- .../feature_tests/dace/test_orchestration.py | 5 +- .../ffront_tests/ffront_test_utils.py | 6 ++- .../iterator_tests/test_program.py | 1 + .../dace_tests/test_gtir_to_sdfg.py | 53 +++---------------- 10 files changed, 119 insertions(+), 84 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 2db8e98804..9a45b6a29a 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,8 +8,8 @@ import factory -from gt4py.next import allocators as next_allocators, backend -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next import backend +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory @@ -25,12 +25,12 @@ class Params: ), ) auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_temps="_opt" + otf_workflow__translation__auto_optimize=True, name_postfix="_opt" ) use_field_canonical_representation: bool = False name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" + lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" ) transforms = backend.DEFAULT_TRANSFORMS @@ -45,12 +45,28 @@ class Params: itir_cpu = run_dace_cpu itir_gpu = run_dace_gpu -gtir_cpu = backend.Backend( - name="dace.gtir.cpu", - executor=dace_fieldview_workflow.DaCeWorkflowFactory(), - allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms=backend.Transforms( + +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Params: + otf_workflow = factory.SubFactory( + dace_fieldview_workflow.DaCeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + ) + auto_optimize = factory.Trait(name_postfix="_opt") + + name = factory.LazyAttribute( + lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + ) + + transforms = backend.Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), - ), -) + foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(), + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() + ), + ) + + +gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) +gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 5d3cc7a358..bbf45a822c 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -32,7 +32,7 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: # Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar, # which may change its precision. To avoid this, we use here the empty tuple as index # for 'ndarray.__getitem__()'. - return arg.ndarray[()] + return arg.asnumpy()[()] # field domain offsets are not supported non_zero_offsets = [ (dim, dim_range) @@ -88,10 +88,19 @@ def _get_shape_args( for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): if isinstance(sym, dace.symbol): - assert sym.name not in shape_args - shape_args[sym.name] = size + if sym.name not in shape_args: + shape_args[sym.name] = size + elif shape_args[sym.name] != size: + # The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields + # in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array + # size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple. + # TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique, + # once the assumption on tuples is removed. + raise ValueError( + f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}." + ) elif sym != size: - raise RuntimeError( + raise ValueError( f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." ) return shape_args @@ -109,10 +118,17 @@ def _get_stride_args( f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) if isinstance(sym, dace.symbol): - assert sym.name not in stride_args - stride_args[str(sym)] = stride + if sym.name not in stride_args: + stride_args[str(sym)] = stride + elif stride_args[sym.name] != stride: + # See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple. + # TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique, + # once the assumption on tuples is removed. + raise ValueError( + f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}." + ) elif sym != stride: - raise RuntimeError( + raise ValueError( f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}." ) return stride_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index ae0a24605d..91e83dba9d 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -17,7 +17,7 @@ from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import common, config, utils as gtx_utils from gt4py.next.otf import arguments, languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils @@ -116,7 +116,7 @@ def decorated_program( args = (*args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, args, strict=True)) + kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True)) kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) use_fast_call = True diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 48c666a363..f19f78d9d2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -217,6 +217,7 @@ def _add_storage( name: str, gt_type: ts.DataType, transient: bool = True, + tuple_name: Optional[str] = None, ) -> list[tuple[str, ts.DataType]]: """ Add storage in the SDFG for a given GT4Py data symbol. @@ -236,6 +237,7 @@ def _add_storage( name: Symbol Name to be allocated. gt_type: GT4Py symbol type. transient: True when the data symbol has to be allocated as internal storage. + tuple_name: Must be set for tuple fields in order to use the same array shape and strides symbols. Returns: List of tuples '(data_name, gt_type)' where 'data_name' is the name of @@ -250,7 +252,9 @@ def _add_storage( name, gt_type, flatten=True ): tuple_fields.extend( - self._add_storage(sdfg, symbolic_arguments, tname, tsymbol_type, transient) + self._add_storage( + sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + ) ) return tuple_fields @@ -260,16 +264,23 @@ def _add_storage( return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions dc_dtype = dace_utils.as_dace_type(gt_type.dtype) - # use symbolic shape, which allows to invoke the program with fields of different size; - # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) + if tuple_name is None: + # Use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) + else: + # All fields in a tuple must have the same dims and sizes, + # therefore we use the same shape and strides symbols based on 'tuple_name'. + sym_shape, sym_strides = self._make_array_shape_and_strides( + tuple_name, gt_type.dims + ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): dc_dtype = dace_utils.as_dace_type(gt_type) - if name in symbolic_arguments: + if dace_utils.is_field_symbol(name) or name in symbolic_arguments: if name in sdfg.symbols: # Sometimes, when the field domain is implicitly derived from the # field domain, the gt4py lowering adds the field size as a scalar diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index f2953eb05f..85ae95c432 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -16,14 +16,16 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config +from gt4py.next import allocators as gtx_allocators, common, config from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -from gt4py.next.type_system import type_translation as tt +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_sdfg, + transformations as gtx_transformations, +) @dataclasses.dataclass(frozen=True) @@ -33,7 +35,8 @@ class DaCeTranslator( ], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + device_type: core_defs.DeviceType + auto_optimize: bool def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -45,9 +48,18 @@ def generate_sdfg( ir: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], + auto_opt: bool, + on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - return gtir_sdfg.build_sdfg_from_gtir(ir=ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + + if auto_opt: + gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) + elif on_gpu: + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=False) + + return sdfg def __call__( self, inp: stages.CompilableProgram @@ -60,11 +72,13 @@ def __call__( program, inp.args.offset_provider, inp.args.column_axis, + auto_opt=self.auto_optimize, + on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), ) param_types = tuple( - interface.Parameter(param, tt.from_value(arg)) - for param, arg in zip(sdfg.arg_names, inp.args.args) + interface.Parameter(param, arg_type) + for param, arg_type in zip(sdfg.arg_names, inp.args.args) ) module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( @@ -98,10 +112,12 @@ class Params: cmake_build_type: config.CMakeBuildType = factory.LazyFunction( lambda: config.CMAKE_BUILD_TYPE ) + auto_optimize: bool = False translation = factory.SubFactory( DaCeTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), ) bindings = _no_bindings compilation = factory.SubFactory( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 3fef43865b..1bcc3554a7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -71,6 +71,7 @@ class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu" GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" + GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -145,11 +146,14 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ - (ALL, SKIP, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), @@ -177,6 +181,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST + + [ + # TODO(edopao): Enable when GPU codegen issues related to symbolic domain are fixed. + (ALL, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 306f0034b5..1da34db3c0 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -32,7 +32,10 @@ try: import dace - from gt4py.next.program_processors.runners.dace import run_dace_cpu, run_dace_gpu + from gt4py.next.program_processors.runners.dace import ( + itir_cpu as run_dace_cpu, + itir_gpu as run_dace_gpu, + ) except ImportError: dace: Optional[ModuleType] = None # type:ignore[no-redef] run_dace_cpu: Optional[next_backend.Backend] = None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 333a2dae28..0ed3365969 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -62,12 +62,16 @@ def __gt_allocator__( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index db1c2a42aa..f6fd0a48d0 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -88,6 +88,7 @@ def index_program_shift(out, size): ) +@pytest.mark.starts_from_gtir_program @pytest.mark.uses_index_builtin def test_index_builtin_shift(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9f5498b4a7..dea9f2879b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -262,16 +262,8 @@ def test_gtir_tuple_args(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) x_fields = (a, a, b) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_0_size_0=FSYMBOLS["__x_size_0"], - __x_1_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_1_size_0=FSYMBOLS["__y_size_0"], - __x_1_1_stride_0=FSYMBOLS["__y_stride_0"], - ) - sdfg(*x_fields, c, **FSYMBOLS, **x_symbols) + sdfg(*x_fields, c, **FSYMBOLS) assert np.allclose(c, a * 2 + b) @@ -432,16 +424,8 @@ def test_gtir_tuple_return(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_0_1_size_0=FSYMBOLS["__x_size_0"], - __z_0_1_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], a) assert np.allclose(z_fields[2], b) @@ -694,18 +678,11 @@ def test_gtir_cond_with_tuple_return(): b = np.random.rand(N) c = np.random.rand(N) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) for s in [False, True]: z_fields = (np.empty_like(a), np.empty_like(a)) - sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **z_symbols) + sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS) assert np.allclose(z_fields[0], a if s else b) assert np.allclose(z_fields[1], b if s else a) @@ -1833,14 +1810,8 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a) assert np.allclose(z_fields[1], b) @@ -1879,16 +1850,8 @@ def test_gtir_let_lambda_with_tuple2(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - __z_2_size_0=FSYMBOLS["__x_size_0"], - __z_2_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + sdfg(a, b, *z_fields, **FSYMBOLS) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], val) assert np.allclose(z_fields[2], b) @@ -1939,13 +1902,9 @@ def test_gtir_if_scalars(): d2 = np.random.randint(0, 1000) sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - ) for s in [False, True]: - sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **x_symbols) + sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS) assert np.allclose(b, (a + d1 if s else a + d2)) From 60bb7b17963d93d430725f7ddecbc9dc61b4d71f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 5 Nov 2024 12:45:58 +0100 Subject: [PATCH 034/178] fix[cartesian]: While loops inside conditions (#1712) This is a follow-up from PR https://github.com/GridTools/gt4py/pull/1630. It combines two things 1. In `oir_to_npir`, we fix conditional `while` loops in `numpy` backend. After PR 1630 these were stuck under certain conditions. Tests coverage extended. 2. In `gtir_to_oir`, we cleaned up the now unused `mask` parameter, which was pre-PR 1630 needed to pass down the mask information. With PR 1630 we actually removed the need for that parameter to be passed along because we properly nest the nested statements. --- src/gt4py/cartesian/gtc/gtir_to_oir.py | 53 ++++--------------- src/gt4py/cartesian/gtc/numpy/oir_to_npir.py | 9 ++-- .../multi_feature_tests/test_suites.py | 30 +++++++++++ .../unit_tests/test_gtc/test_oir_to_npir.py | 13 +++++ 4 files changed, 58 insertions(+), 47 deletions(-) diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 560cbf96cf..d36c2e5c4a 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -7,11 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass, field -from typing import Any, List, Optional, Set, Union +from typing import Any, List, Set, Union from gt4py import eve -from gt4py.cartesian.gtc import common, gtir, oir, utils -from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator +from gt4py.cartesian.gtc import gtir, oir, utils +from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents @@ -118,15 +118,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: ) # --- Statements --- - def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.AssignStmt, oir.MaskStmt]: - statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) - if mask is None: - return statement - - # Wrap inside MaskStmt - return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) + def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt: + return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any @@ -138,24 +131,19 @@ def visit_HorizontalRestriction( return oir.HorizontalRestriction(mask=node.mask, body=body) - def visit_While( - self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> oir.While: + def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While: body: List[oir.Stmt] = [] for statement in node.body: oir_statement = self.visit(statement, **kwargs) body.extend(utils.flatten_list(utils.listify(oir_statement))) condition: oir.Expr = self.visit(node.cond) - if mask: - condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) return oir.While(cond=condition, body=body, loc=node.loc) def visit_FieldIfStmt( self, node: gtir.FieldIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: @@ -182,26 +170,17 @@ def visit_FieldIfStmt( loc=node.loc, ) - combined_mask: oir.Expr = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc)) if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements @@ -211,31 +190,21 @@ def visit_ScalarIfStmt( self, node: gtir.ScalarIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[oir.MaskStmt]: condition = self.visit(node.cond) - combined_mask = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc - ) - body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) - if mask: - combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index ed573ebfff..b6aeb49823 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -157,13 +157,12 @@ def visit_AssignStmt( def visit_While( self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> npir.While: - cond = self.visit(node.cond, mask=mask, **kwargs) + cond_expr = self.visit(node.cond, **kwargs) if mask: - mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) - else: - mask = cond + cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr) + return npir.While( - cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs)) + cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs)) ) def visit_HorizontalRestriction( diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index d3a5744389..0312aea7c3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -444,6 +444,36 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs): field_a += 1 +class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite): + """Test conditional while statements.""" + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ALL_BACKENDS + symbols = dict( + infield=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + with computation(PARALLEL), interval(...): + if infield < 10: + outfield = 1 + done = False + while not done: + outfield = 2 + done = True + else: + condition = True + while condition: + outfield = 4 + condition = False + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + class TestTernaryOp(gt_testing.StencilTestSuite): dtypes = (np.float_,) domain_range = [(1, 15), (2, 15), (1, 15)] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py index 4de7f9f5d6..4877a39503 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py @@ -28,6 +28,7 @@ StencilFactory, VerticalLoopFactory, VerticalLoopSectionFactory, + WhileFactory, ) @@ -78,6 +79,18 @@ def test_mask_stmt_to_assigns() -> None: assert len(assign_stmts) == 1 +def test_mask_stmt_to_while() -> None: + mask_oir = MaskStmtFactory(body=[WhileFactory()]) + statements = OirToNpir().visit(mask_oir, extent=Extent.zeros(ndims=2)) + assert len(statements) == 1 + assert isinstance(statements[0], npir.While) + condition = statements[0].cond + assert isinstance(condition, npir.VectorLogic) + assert condition.op == common.LogicalOperator.AND + mask_npir = OirToNpir().visit(mask_oir.mask) + assert condition.left == mask_npir or condition.right == mask_npir + + def test_mask_propagation() -> None: mask_stmt = MaskStmtFactory() assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2)) From 6873a0e7f87e43ed717cc22611793a81afce0c93 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 6 Nov 2024 09:54:22 +0100 Subject: [PATCH 035/178] fix[next][dace]: Fix for nested SDFG outer data descriptor (#1726) Fixes a bug in lowering of let-statements. The lambda expression of a let-statement is lowered to a nested SDFG. The result data produced in the nested SDFG is written to temporary data allocated in the parent SDFG. The previous lowering was directly using a copy of the inner data descriptor for the outer data. The bug is that some symbols for array shape and strides might not be available in the parent SDFG, so we have to apply the symbol mapping on the outer data descriptor. The test case `test_gtir_let_lambda_with_cond` was modified to trigger this bug and verify the fix. --- .../runners/dace_fieldview/gtir_sdfg.py | 66 +++++++++++-------- .../dace_tests/test_gtir_to_sdfg.py | 16 ++--- 2 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f19f78d9d2..da940e883c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -709,49 +709,61 @@ def _flatten_tuples( head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) - def make_temps( - output_data: gtir_builtin_translators.FieldopData, + def construct_output_for_nested_sdfg( + inner_data: gtir_builtin_translators.FieldopData, ) -> gtir_builtin_translators.FieldopData: """ - This function will be called while traversing the result of the lambda - dataflow to setup the intermediate data nodes in the parent SDFG and - the data edges from the nested-SDFG output connectors. + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. + The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. """ - desc = output_data.dc_node.desc(nsdfg) - if desc.transient: - # Transient nodes actually contain some result produced by the dataflow - # itself, therefore these nodes are changed to non-transient and an output - # edge will write the result from the nested-SDFG to a new intermediate - # data node in the parent context. - desc.transient = False - temp, _ = sdfg.add_temp_transient_like(desc) - connector = output_data.dc_node.data - dst_node = head_state.add_access(temp) + inner_desc = inner_data.dc_node.desc(nsdfg) + if inner_desc.transient: + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. + inner_desc.transient = False + outer, outer_desc = sdfg.add_temp_transient_like(inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + nsdfg_symbols_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + connector = inner_data.dc_node.data + outer_node = head_state.add_access(outer) head_state.add_edge( - nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) + nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - temp_field = gtir_builtin_translators.FieldopData( - dst_node, output_data.gt_dtype, output_data.local_offset + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) - elif output_data.dc_node.data in lambda_arg_nodes: + elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - temp_field = lambda_arg_nodes[output_data.dc_node.data] + outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: - dc_node = head_state.add_access(output_data.dc_node.data) - temp_field = gtir_builtin_translators.FieldopData( - dc_node, output_data.gt_dtype, output_data.local_offset + outer_node = head_state.add_access(inner_data.dc_node.data) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_dtype, inner_data.local_offset ) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. - if nstate.degree(output_data.dc_node) == 0: - nstate.remove_node(output_data.dc_node) - return temp_field + if nstate.degree(inner_data.dc_node) == 0: + nstate.remove_node(inner_data.dc_node) + return outer_data - return gtx_utils.tree_map(make_temps)(lambda_result) + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) def visit_Literal( self, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index dea9f2879b..cc72adae4f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -621,7 +621,7 @@ def test_gtir_cond(): expr=im.op_as_fieldop("plus", domain)( "x", im.if_( - im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")), + im.greater("s1", "s2"), im.op_as_fieldop("plus", domain)("y", "scalar"), im.op_as_fieldop("plus", domain)("w", "scalar"), ), @@ -663,7 +663,7 @@ def test_gtir_cond_with_tuple_return(): expr=im.tuple_get( 0, im.if_( - gtir.SymRef(id="pred"), + "pred", im.make_tuple(im.make_tuple("x", "y"), "w"), im.make_tuple(im.make_tuple("y", "x"), "w"), ), @@ -703,10 +703,10 @@ def test_gtir_cond_nested(): body=[ gtir.SetAt( expr=im.if_( - gtir.SymRef(id="pred_1"), + "pred_1", im.op_as_fieldop("plus", domain)("x", 1.0), im.if_( - gtir.SymRef(id="pred_2"), + "pred_2", im.op_as_fieldop("plus", domain)("x", 2.0), im.op_as_fieldop("plus", domain)("x", 3.0), ), @@ -1534,7 +1534,7 @@ def test_gtir_reduce_with_cond_neighbors(): vertex_domain, )( im.if_( - gtir.SymRef(id="pred"), + "pred", im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) @@ -1756,11 +1756,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", "x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.if_( - gtir.SymRef(id="pred"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), - ) + im.if_("pred", "x1", "x2") ) ), domain=domain, From 1b9eb5c044ba3a93c1af5f0005f01c835fa79ec0 Mon Sep 17 00:00:00 2001 From: SF-N Date: Thu, 7 Nov 2024 15:04:37 +0100 Subject: [PATCH 036/178] feat[next]: Add memory and disk-based caching to more workflow steps (#1690) Add memory and disk-based caching to other workflow steps and, therefore, removing unnecessary overhead of Program calls and significantly improving time to first computed value. Changes: - setting `cached = True` for `func_to_past_factory` - wrapping the GTFN code generation into a `CachedStep` (using `Diskcache`) which is activated when setting `otf_workflow__cached_translation=True`, similar as in [PR#1474](https://github.com/GridTools/gt4py/pull/1474) (without CachedStep) - Fixing hash function of `ProgramDefinition` - New tests for added functionality This leads to a runtime decrease of about 25% for PMAP-G in the advect-uniform testcase (5 hours) after caches are populated. TODOs: - [x] improving hash functions of `fingerprint_stage` --------- Co-authored-by: Till Ehrengruber Co-authored-by: Enrique G. Paredes --- .pre-commit-config.yaml | 3 +- constraints.txt | 9 +- min-extra-requirements-test.txt | 1 + min-requirements-test.txt | 1 + pyproject.toml | 1 + requirements-dev.txt | 9 +- src/gt4py/next/ffront/func_to_past.py | 2 +- src/gt4py/next/ffront/stages.py | 8 +- src/gt4py/next/iterator/ir.py | 4 +- src/gt4py/next/otf/workflow.py | 8 +- .../codegens/gtfn/gtfn_module.py | 1 + .../next/program_processors/runners/gtfn.py | 61 ++++++++-- .../ffront_tests/test_execution.py | 5 +- .../gtfn_tests/test_gtfn_module.py | 112 ++++++++++++++++++ 14 files changed, 200 insertions(+), 25 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 880a422160..f2f5b73613 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -94,13 +94,14 @@ repos: - astunparse==1.6.3 - attrs==24.2.0 - black==24.8.0 - - boltons==24.0.0 + - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - cmake==3.30.5 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 + - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - gridtools-cpp==2.3.6 diff --git a/constraints.txt b/constraints.txt index e846d4126c..e7acc466cd 100644 --- a/constraints.txt +++ b/constraints.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, babel==2.16.0 # via sphinx backcall==0.2.0 # via ipython black==24.8.0 # via gt4py (pyproject.toml) -boltons==24.0.0 # via gt4py (pyproject.toml) +boltons==24.1.0 # via gt4py (pyproject.toml) bracex==2.5.post1 # via wcmatch build==1.2.2.post1 # via pip-tools -bump-my-version==0.28.0 # via -r requirements-dev.in +bump-my-version==0.28.1 # via -r requirements-dev.in cached-property==2.0.1 # via gt4py (pyproject.toml) cachetools==5.5.0 # via tox certifi==2024.8.30 # via requests @@ -40,6 +40,7 @@ decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.9 # via dace +diskcache==5.6.3 # via gt4py (pyproject.toml) distlib==0.3.9 # via virtualenv docutils==0.20.1 # via sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via hypothesis, pytest @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx -rich==13.9.3 # via bump-my-version, rich-click, tach +rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing ruff==0.7.2 # via -r requirements-dev.in @@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.1 # via -r requirements-dev.in +tach==0.14.2 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 7fea11bc3d..f63042906c 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -65,6 +65,7 @@ dace==0.16.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 gridtools-cpp==2.3.6 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index c20883e25e..666aa79107 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -61,6 +61,7 @@ cytoolz==0.12.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 gridtools-cpp==2.3.6 diff --git a/pyproject.toml b/pyproject.toml index 64f08e671e..c9f7b3b50b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', + 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.6,==2.*', diff --git a/requirements-dev.txt b/requirements-dev.txt index eb757e0afd..a036307e80 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo babel==2.16.0 # via -c constraints.txt, sphinx backcall==0.2.0 # via -c constraints.txt, ipython black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) -boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml) +boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml) bracex==2.5.post1 # via -c constraints.txt, wcmatch build==1.2.2.post1 # via -c constraints.txt, pip-tools -bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in +bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cachetools==5.5.0 # via -c constraints.txt, tox certifi==2024.8.30 # via -c constraints.txt, requests @@ -40,6 +40,7 @@ decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) dill==0.3.9 # via -c constraints.txt, dace +diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml) distlib==0.3.9 # via -c constraints.txt, virtualenv docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach +rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in @@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index f415c95b63..09f53be600 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG: ) -def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index bf3bee4b56..834536ff59 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(toolchain.CompilableProgram) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) @@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo for item in sourcedef: add_content_to_fingerprint(item, hasher) + closure_vars = source_utils.get_closure_vars_from_function(obj) + for item in sorted(closure_vars.items(), key=lambda x: x[0]): + add_content_to_fingerprint(item, hasher) + @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in obj.items(): + for key, value in sorted(obj.items()): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint( ) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b6f543e9d1..f50d8080eb 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): closures: List[StencilClosure] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(BUILTINS) + ] # sorted for serialization stability class Stmt(Node): ... diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index a63801c97e..ef3a4083b9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,6 +12,7 @@ import dataclasses import functools import typing +from collections.abc import MutableMapping from typing import Any, Callable, Generic, Protocol, TypeVar from typing_extensions import Self @@ -253,16 +254,15 @@ class CachedStep( step: Workflow[StartT, EndT] hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - - _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) + cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" hash_ = self.hash_function(inp) try: - result = self._cache[hash_] + result = self.cache[hash_] except KeyError: - result = self._cache[hash_] = self.step(inp) + result = self.cache[hash_] = self.step(inp) return result diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 07eec0b64b..66d74d53cc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -213,6 +213,7 @@ def generate_stencil_source( generated_code = GTFNIMCodegen.apply(gtfn_im_ir) else: generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 2275576081..4a788bf40c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -8,16 +8,19 @@ import functools import warnings -from typing import Any +from typing import Any, Optional +import diskcache import factory import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators +from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import transforms +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir, transforms from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int: ) +def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: + """ + Generates a unique hash string for a stencil source program representing + the program, sorted offset_provider, and column_axis. + """ + program: itir.FencilDefinition | itir.Program = inp.data + offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + column_axis: Optional[common.Dimension] = inp.args.column_axis + + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + return program_hash + + +class FileCache(diskcache.Cache): + """ + This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, + i.e. it ensures that any resources associated with the cache are properly + released when the instance is garbage collected. + """ + + def __del__(self) -> None: + self.close() + + class GTFNCompileWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -129,10 +163,23 @@ class Params: lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) - translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) + cached_translation = factory.Trait( + translation=factory.LazyAttribute( + lambda o: workflow.CachedStep( + o.translation_, + hash_function=fingerprint_compilable_program, + cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + ) + ), + ) + + translation_ = factory.SubFactory( + gtfn_module.GTFNTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + ) + + translation = factory.LazyAttribute(lambda o: o.translation_) + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source ) @@ -193,7 +240,7 @@ class Params: name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) -run_gtfn_cached = GTFNBackendFactory(cached=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7540d52fb3..27f94960dc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,9 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce - +from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf.binding import interface import numpy as np import pytest +import diskcache +from gt4py.eve import SymbolName import gt4py.next as gtx from gt4py.next import ( diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e3e0ee474f..e64bd8a57d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -8,13 +8,25 @@ import numpy as np import pytest +import copy +import diskcache + import gt4py.next as gtx from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim + +from next_tests.integration_tests.cases import cartesian_case + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) @pytest.fixture @@ -71,3 +83,103 @@ def test_codegen(fencil_example): assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.CPP + + +def test_hash_and_diskcache(fencil_example, tmp_path): + fencil, parameters = fencil_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + hash = gtfn.fingerprint_compilable_program(compilable_program) + + with diskcache.Cache(tmp_path) as cache: + cache[hash] = compilable_program + + # check content of cash file + with diskcache.Cache(tmp_path) as reopened_cache: + assert hash in reopened_cache + compilable_program_from_cache = reopened_cache[hash] + assert compilable_program == compilable_program_from_cache + del reopened_cache[hash] # delete data + + # hash creation is deterministic + assert hash == gtfn.fingerprint_compilable_program(compilable_program) + assert hash == gtfn.fingerprint_compilable_program(compilable_program_from_cache) + + # hash is different if program changes + altered_program_id = copy.deepcopy(compilable_program) + altered_program_id.data.id = "example2" + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_id) + + altered_program_offset_provider = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim}) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_offset_provider) + + altered_program_column_axis = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_column_axis.args, "column_axis", KDim) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) + + +def test_gtfn_file_cache(fencil_example): + fencil, parameters = fencil_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + cached_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ).executor.step.translation + + bare_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=False + ).executor.step.translation + + cache_key = gtfn.fingerprint_compilable_program(compilable_program) + + # ensure the actual cached step in the backend generates the cache item for the test + if cache_key in (translation_cache := cached_gtfn_translation_step.cache): + del translation_cache[cache_key] + cached_gtfn_translation_step(compilable_program) + assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step( + compilable_program + ) + + assert cache_key in cached_gtfn_translation_step.cache + assert ( + bare_gtfn_translation_step(compilable_program) + == cached_gtfn_translation_step.cache[cache_key] + ) + + +# TODO(egparedes): we should switch to use the cached backend by default and then remove this test +def test_gtfn_file_cache_whole_workflow(cartesian_case): + if cartesian_case.backend != gtfn.run_gtfn: + pytest.skip("Skipping backend.") + cartesian_case.backend = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ) + + @gtx.field_operator + def testee(a: cases.IJKField) -> cases.IJKField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + # first call: this generates the cache file + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + object.__setattr__(cartesian_case.backend.executor, "cache", {}) + # second call: the cache file is used + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) From 855ddc8101bd7fbfefe4ae095084f9061b9f8543 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 8 Nov 2024 12:05:44 +0100 Subject: [PATCH 037/178] fix[next]: Bugfix in dace-ITIR backend: use canonical name for field shape symbol (#1730) Shape and stride symbols are expected to match a canonical string pattern. This PR adopts the canonical pattern in ITIR dace backend. Note: it solves an issue encountered in dace orchestration tests. --- .../runners/dace_iterator/itir_to_sdfg.py | 3 ++- .../program_processors/runners/dace_iterator/utility.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a824760ce4..a0f4b83d35 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -15,6 +15,7 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind from gt4py.next.common import Connectivity +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -103,7 +104,7 @@ def _make_array_shape_and_strides( tuple(shape, strides) The output tuple fields are arrays of dace symbolic expressions. """ - dtype = dace.int32 + dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) neighbor_tables = dace_utils.filter_connectivities(offset_provider) shape = [ diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d808fbfbe1..d367eb0883 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -14,6 +14,7 @@ import gt4py.next.iterator.ir as itir from gt4py import eve from gt4py.next.common import Connectivity +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -132,9 +133,11 @@ def unique_var_name(): def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.int64 - shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)] - strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i in range(ndim)] + dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) + shape = [dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) for i in range(ndim)] + strides = [ + dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i in range(ndim) + ] return shape, strides From 5ce0d9d9c569c7172dd2284a45bf67ff1ba68b31 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 11 Nov 2024 10:41:25 +0100 Subject: [PATCH 038/178] build: Bump gridtools-cpp to 2.3.7 in preparation of #1648 (#1732) #1648 exposed a compilation problem with nvcc which has been fixed in https://github.com/GridTools/gridtools/pull/1811 included in gridtools 2.3.7. --- .pre-commit-config.yaml | 8 ++++---- constraints.txt | 16 ++++++++-------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 16 ++++++++-------- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2f5b73613..93ea4685f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.2 + rev: v0.7.3 ##[[[end]]] hooks: # Run the linter. @@ -97,14 +97,14 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.30.5 + - cmake==3.31.0.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.6 + - gridtools-cpp==2.3.7 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 @@ -112,7 +112,7 @@ repos: - nanobind==2.2.0 - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==24.1 + - packaging==24.2 - pybind11==2.13.6 - setuptools==75.3.0 - tabulate==0.9.0 diff --git a/constraints.txt b/constraints.txt index e7acc466cd..4aca6645d5 100644 --- a/constraints.txt +++ b/constraints.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via gt4py (pyproject.toml) +cmake==3.31.0.1 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==0.16.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.7 # via ipykernel +debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.6 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel -jedi==0.19.1 # via ipython +jedi==0.19.2 # via ipython jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.2 # via -r requirements-dev.in +ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.2 # via -r requirements-dev.in +tach==0.14.3 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version @@ -174,7 +174,7 @@ virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit websockets==13.1 # via dace -wheel==0.44.0 # via astunparse, pip-tools +wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index f63042906c..6fd3d1af55 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -68,7 +68,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 666aa79107..b8779096c0 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -64,7 +64,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.6 +gridtools-cpp==2.3.7 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index c9f7b3b50b..7d63f70f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.6,==2.*', + 'gridtools-cpp>=2.3.7,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index a036307e80..8892620786 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,7 +25,7 @@ chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.5 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.7 # via -c constraints.txt, ipykernel +debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -66,7 +66,7 @@ inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel -jedi==0.19.1 # via -c constraints.txt, ipython +jedi==0.19.2 # via -c constraints.txt, ipython jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version @@ -173,7 +173,7 @@ virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit websockets==13.1 # via -c constraints.txt, dace -wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources From 89fea8fa86eceea7e7cadb1b16924354267ebf0c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 14:44:23 +0100 Subject: [PATCH 039/178] bug[next]: Fix ITIR program hash stability (#1733) #1690 included a change to make the hash of an `itir.FencilDefinition` stable across multiple runs. This PR adopts the same change to an `itir.Program`, --- src/gt4py/next/iterator/ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index f50d8080eb..7098e9fa2e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -242,7 +242,9 @@ class Program(Node, ValidatedSymbolTableTrait): body: List[Stmt] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(GTIR_BUILTINS) + ] # sorted for serialization stability # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) From b60ffff3f14d272cfc5ee470c80b460358fd8add Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Nov 2024 17:46:41 +0100 Subject: [PATCH 040/178] build: Bump gridtools-cpp to 2.3.8 in preparation of #1648 (#1737) #1648 exposed a compilation problem with nvcc which has been fixed in #1812 included in gridtools 2.3.8. --- .pre-commit-config.yaml | 2 +- constraints.txt | 10 +++++----- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 10 +++++----- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93ea4685f4..f56e84f8d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -104,7 +104,7 @@ repos: - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - - gridtools-cpp==2.3.7 + - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 diff --git a/constraints.txt b/constraints.txt index 4aca6645d5..4247f4951d 100644 --- a/constraints.txt +++ b/constraints.txt @@ -47,7 +47,7 @@ exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==30.8.2 # via factory-boy +faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat filelock==3.16.1 # via tox, virtualenv fonttools==4.54.1 # via matplotlib @@ -55,7 +55,7 @@ fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach -gridtools-cpp==2.3.7 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.8 # via gt4py (pyproject.toml) hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via pre-commit idna==3.10 # via requests @@ -137,7 +137,7 @@ questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.3 # via bump-my-version +rich-click==1.8.4 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing ruff==0.7.3 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) @@ -147,7 +147,7 @@ smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.1 # via -r requirements-dev.in +sphinx-rtd-theme==3.0.2 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -160,7 +160,7 @@ stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) tach==0.14.3 # via -r requirements-dev.in -tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 6fd3d1af55..4190570105 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -68,7 +68,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.7 +gridtools-cpp==2.3.8 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jax[cpu]==0.4.18; python_version >= "3.10" diff --git a/min-requirements-test.txt b/min-requirements-test.txt index b8779096c0..81a1c2dea3 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -64,7 +64,7 @@ devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 -gridtools-cpp==2.3.7 +gridtools-cpp==2.3.8 hypothesis==6.0.0 importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 diff --git a/pyproject.toml b/pyproject.toml index 7d63f70f15..1504c8b17b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.7,==2.*', + 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index 8892620786..ca7eb32487 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -47,7 +47,7 @@ exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==30.8.2 # via -c constraints.txt, factory-boy +faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, tox, virtualenv fonttools==4.54.1 # via -c constraints.txt, matplotlib @@ -55,7 +55,7 @@ fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.7 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.8 # via -c constraints.txt, gt4py (pyproject.toml) hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) identify==2.6.1 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests @@ -137,7 +137,7 @@ questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.3 # via -c constraints.txt, bump-my-version +rich-click==1.8.4 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser @@ -146,7 +146,7 @@ smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.1 # via -c constraints.txt, -r requirements-dev.in +sphinx-rtd-theme==3.0.2 # via -c constraints.txt, -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx @@ -159,7 +159,7 @@ stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in -tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz From c51bdd1b6e515b2cebff466876d51b1bc0874096 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 10:35:42 +0100 Subject: [PATCH 041/178] fix[next]: Fix type preservation in CSE (#1736) The common subexpression elimination uses typing information to decide what expressions can be extracted. However, while extracting it creates new nodes and uses the inline lambda pass, which did not preserve the types. This was observed in PMAP and is fixed in this PR on a best effort basis. Creating a minimal reproducible example is hard and since multiple of us are considering making typing information an integral part of the IR, e.g. by attaching the computation to the node instead of having a separate pass, which would solve the problem automatically no tests have been written. --- src/gt4py/next/iterator/transforms/cse.py | 12 ++++++------ src/gt4py/next/iterator/transforms/inline_lambdas.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 6 +++--- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccc1d2195f..4932d376ad 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -241,7 +242,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -433,7 +433,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if num_occurences > 1: if is_local_view: return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the transformation + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` @@ -451,10 +453,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - ) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..399a7a3dc6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -98,7 +99,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +111,8 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index edcb9b540c..66d8345b94 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ - assert isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) # type: ignore[arg-type] def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: From 998f2792de75650447a1ffd96aec1e4ebc8dc882 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 11:27:56 +0100 Subject: [PATCH 042/178] feat[next]: GTIR embedded and GTFN temporaries with new lowering (#1648) Use new lowering for GTIR embedded, and GTFN. Only the dace iterator backend continues to use the old lowering. Changes: - Use GTIR lowering for all backends except for dace - Old lowering and transformations only used in dace backend - workflows defined in [`gt4py.next.backend.LEGACY_TRANSFORMS`](https://github.com/GridTools/gt4py/pull/1648/files#diff-cf4385d02cbeacc310d4326350903b4cb6f9a61c7cd36dda162a5077ab8b8e86). Variable can be removed in a cleanup PR. - old `apply_common_transforms` in [pass_manager_legacy.py](https://github.com/GridTools/gt4py/pull/1648/files#diff-db17bff48ac16ee75ff974a1b9af98e3cf0c850971ce9898aa55b635bb046b72). Just a straight copy of the old function. No need to review, this is just to avoid deleting until gtir based dace backend is ready. - Re-add `symbolic_sizes` param. Was in temporary extraction, is now part of the domain inference. In preparation of icon-exclaim tests --------- Co-authored-by: Hannes Vogt --- .pre-commit-config.yaml | 1 - src/gt4py/next/backend.py | 14 +- src/gt4py/next/ffront/foast_to_past.py | 6 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- .../next/iterator/ir_utils/domain_utils.py | 23 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 +- .../next/iterator/transforms/__init__.py | 3 +- .../iterator/transforms/collapse_list_get.py | 52 +++--- .../iterator/transforms/collapse_tuple.py | 85 ++++++--- src/gt4py/next/iterator/transforms/cse.py | 43 +++-- .../iterator/transforms/fuse_as_fieldop.py | 44 +++-- .../next/iterator/transforms/global_tmps.py | 8 +- .../next/iterator/transforms/infer_domain.py | 140 ++++++++++---- .../iterator/transforms/inline_into_scan.py | 2 +- .../iterator/transforms/inline_lambdas.py | 7 +- .../next/iterator/transforms/inline_scalar.py | 31 +++ .../next/iterator/transforms/pass_manager.py | 176 ++++++------------ .../transforms/pass_manager_legacy.py | 175 +++++++++++++++++ .../next/iterator/transforms/remap_symbols.py | 8 +- .../next/iterator/transforms/unroll_reduce.py | 9 +- .../iterator/type_system/type_synthesizer.py | 6 +- .../codegens/gtfn/gtfn_ir.py | 58 ++++-- .../codegens/gtfn/gtfn_module.py | 17 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 26 ++- .../program_processors/formatters/lisp.py | 4 +- .../next/program_processors/runners/dace.py | 11 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/workflow.py | 8 +- .../next/program_processors/runners/gtfn.py | 20 +- .../program_processors/runners/roundtrip.py | 16 +- src/gt4py/next/type_system/type_info.py | 4 +- tests/next_tests/definitions.py | 62 +++--- .../ffront_tests/ffront_test_utils.py | 1 - .../ffront_tests/test_decorator.py | 4 +- .../ffront_tests/test_execution.py | 24 ++- .../ffront_tests/test_scalar_if.py | 1 + .../test_temporaries_with_sizes.py | 26 +-- .../iterator_tests/test_builtins.py | 15 +- .../feature_tests/iterator_tests/test_scan.py | 2 + .../ffront_tests/test_icon_like_scan.py | 19 -- .../iterator_tests/test_anton_toy.py | 5 - .../iterator_tests/test_column_stencil.py | 51 ++--- .../iterator_tests/test_fvm_nabla.py | 1 - .../iterator_tests/test_if_stmt.py | 6 +- .../iterator_tests/test_vertical_advection.py | 58 ++---- .../test_with_toy_connectivity.py | 1 - tests/next_tests/unit_tests/conftest.py | 54 +++--- .../transforms_tests/test_collapse_tuple.py | 21 ++- .../transforms_tests/test_cse.py | 26 +-- .../transforms_tests/test_domain_inference.py | 25 ++- .../transforms_tests/test_fuse_as_fieldop.py | 25 +++ .../transforms_tests/test_unroll_reduce.py | 87 ++------- 52 files changed, 945 insertions(+), 586 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_scalar.py create mode 100644 src/gt4py/next/iterator/transforms/pass_manager_legacy.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f56e84f8d9..07f75177ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,6 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml - - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit ##[[[cog diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 0340d61f89..e223d7771c 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,6 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( + foast_to_gtir, foast_to_itir, foast_to_past, func_to_foast, @@ -76,7 +77,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field( - default_factory=foast_to_itir.adapted_foast_to_itir_factory + default_factory=foast_to_gtir.adapted_foast_to_gtir_factory ) field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( @@ -134,6 +135,17 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() +# FIXME[#1582](havogt): remove after refactoring to GTIR +# note: this step is deliberately placed here, such that the cache is shared +_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) +LEGACY_TRANSFORMS: Transforms = Transforms( + past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), + foast_to_itir=_foast_to_itir_step, + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=_foast_to_itir_step + ), +) + # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 312ac686a2..330bc79809 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -12,7 +12,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, - foast_to_itir, + foast_to_gtir, program_ast as past, stages as ffront_stages, type_specifications as ts_ffront, @@ -68,7 +68,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: ... return a - >>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory()) + >>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory()) >>> compile_time_args = arguments.CompileTimeArgs( ... args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params), @@ -169,7 +169,7 @@ def operator_to_program_factory( ) -> workflow.Workflow[AOT_FOP, AOT_PRG]: """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( - foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory() + foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 14d705576e..c0348bb5c6 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -108,7 +108,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR def past_to_itir_factory( - cached: bool = True, to_gtir: bool = False + cached: bool = True, to_gtir: bool = True ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) if cached: diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8eec405136..8f842e1c13 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Literal, Mapping +from typing import Any, Literal, Mapping, Optional import gt4py.next as gtx from gt4py.next import common @@ -93,6 +93,9 @@ def translate( ..., ], offset_provider: common.OffsetProvider, + #: A dictionary mapping axes names to their length. See + #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -119,18 +122,24 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - # note: ugly but cheap re-computation, but should disappear - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) + horizontal_sizes: dict[str, itir.Expr] + if symbolic_domain_sizes is not None: + horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()} + else: + # note: ugly but cheap re-computation, but should disappear + horizontal_sizes = { + k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) + for k, v in _max_domain_sizes_by_location_type(offset_provider).items() + } old_dim = nbt_provider.origin_axis new_dim = nbt_provider.neighbor_axis assert new_dim not in new_ranges or old_dim == new_dim - # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? new_range = SymbolicRange( im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), + horizontal_sizes[new_dim.value], ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) @@ -140,7 +149,9 @@ def translate( raise AssertionError() return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: - return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) + return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( + shift[2:], offset_provider, symbolic_domain_sizes + ) else: raise AssertionError("Number of shifts must be a multiple of 2.") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 19e26f24b6..d7a66b8285 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -10,7 +10,6 @@ from typing import Callable, Optional, Union from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Dict, Tuple from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -412,7 +411,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]], + ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ >>> str( @@ -446,7 +445,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 6f9651a397..aeccb5f26d 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -8,10 +8,9 @@ from gt4py.next.iterator.transforms.pass_manager import ( ITIRTransform, - LiftMode, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index f8a3c08e8f..4a354879ca 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): @@ -18,32 +19,29 @@ class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): - `list_get(i, make_const_list(e))` -> `e` """ - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: node = self.generic_visit(node) - if node.fun == ir.SymRef(id="list_get"): - if isinstance(node.args[1], ir.FunCall): - if node.args[1].fun == ir.SymRef(id="neighbors"): - offset_tag = node.args[1].args[0] - offset_index = ( - ir.OffsetLiteral(value=int(node.args[0].value)) - if isinstance(node.args[0], ir.Literal) - else node.args[ - 0 - ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn - ) - it = node.args[1].args[1] - return ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index] - ), - args=[it], - ) - ], - ) - if node.args[1].fun == ir.SymRef(id="make_const_list"): - return node.args[1].args[0] + if cpm.is_call_to(node, "list_get"): + if cpm.is_call_to(node.args[1], "if_"): + list_idx = node.args[0] + cond, true_val, false_val = node.args[1].args + return im.if_( + cond, + self.visit(im.call("list_get")(list_idx, true_val)), + self.visit(im.call("list_get")(list_idx, false_val)), + ) + if cpm.is_call_to(node.args[1], "neighbors"): + offset_tag = node.args[1].args[0] + offset_index = ( + itir.OffsetLiteral(value=int(node.args[0].value)) + if isinstance(node.args[0], itir.Literal) + else node.args[ + 0 + ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn + ) + it = node.args[1].args[1] + return im.deref(im.shift(offset_tag, offset_index)(it)) + if cpm.is_call_to(node.args[1], "make_const_list"): + return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b61fb2ba87..f84714e779 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -16,6 +16,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -104,9 +105,10 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider=None, + offset_provider: Optional[common.OffsetProvider] = None, + within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes - flags=None, + flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, ) -> ir.Node: @@ -126,6 +128,13 @@ def apply( flags = flags or cls.flags offset_provider = offset_provider or {} + if isinstance(node, (ir.Program, ir.FencilDefinition)): + within_stencil = False + assert within_stencil in [ + True, + False, + ], "Parameter 'within_stencil' mandatory if node is not a 'Program'." + if not ignore_tuple_size: node = itir_type_inference.infer( node, @@ -136,7 +145,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, - ).visit(node) + ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -150,20 +159,23 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: - node = self.generic_visit(node) - return self.fp_transform(node) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + if cpm.is_call_to(node, "as_fieldop"): + kwargs = {**kwargs, "within_stencil": True} + + node = self.generic_visit(node, **kwargs) + return self.fp_transform(node, **kwargs) - def fp_transform(self, node: ir.Node) -> ir.Node: + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: while True: - new_node = self.transform(node) + new_node = self.transform(node, **kwargs) if new_node is None: break assert new_node != node node = new_node return node - def transform(self, node: ir.Node) -> Optional[ir.Node]: + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: if not isinstance(node, ir.FunCall): return None @@ -171,12 +183,14 @@ def transform(self, node: ir.Node) -> Optional[ir.Node]: if self.flags & transformation: assert isinstance(transformation.name, str) method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node) + result = method(node, **kwargs) if result is not None: return result return None - def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_make_tuple_tuple_get( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple") and all( isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") for arg in node.args @@ -202,7 +216,9 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ return first_expr return None - def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_tuple_get_make_tuple( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if ( node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) @@ -219,7 +235,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture @@ -228,7 +244,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let - self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr), **kwargs) # type: ignore[attr-defined] # ensured by is_let ) )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let @@ -238,12 +254,12 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: cond, true_branch, false_branch = node.args[1].args return im.if_( cond, - self.fp_transform(im.tuple_get(idx.value, true_branch)), - self.fp_transform(im.tuple_get(idx.value, false_branch)), + self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), + self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), ) return None - def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` @@ -258,21 +274,27 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir. new_args.append(arg) if bound_vars: - return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) + return self.fp_transform( + im.let(*bound_vars.items())(im.call(node.fun)(*new_args)), **kwargs + ) return None - def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): - return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if kwargs["within_stencil"]: + # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation + # in local-view for now. Revisit. + return None + if not cpm.is_call_to(node, "if_"): - # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -281,12 +303,16 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args - new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) - new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + new_true_branch = self.fp_transform( + _with_altered_arg(node, i, true_branch), **kwargs + ) + new_false_branch = self.fp_transform( + _with_altered_arg(node, i, false_branch), **kwargs + ) return im.if_(cond, new_true_branch, new_false_branch) return None - def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -304,12 +330,15 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: if outer_vars: return self.fp_transform( im.let(*outer_vars.items())( - self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) - ) + self.fp_transform( + im.let(*inner_vars.items())(original_inner_expr), **kwargs + ) + ), + **kwargs, ) return None - def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 4932d376ad..38ea1fd53d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -31,6 +31,21 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def _is_trivial_tuple_expr(node: itir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if cpm.is_call_to(node, "make_tuple") and all( + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_tuple_expr(arg) + for arg in node.args + ): + return True + if cpm.is_call_to(node, "tuple_get") and ( + isinstance(node.args[1], (itir.SymRef, itir.Literal)) + or _is_trivial_tuple_expr(node.args[1]) + ): + return True + return False + + @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type", "domain") @@ -373,7 +388,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): >>> x = itir.SymRef(id="x") >>> plus = lambda a, b: itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b]) >>> expr = plus(plus(x, x), plus(x, x)) - >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True)) + >>> print(CommonSubexpressionElimination.apply(expr, within_stencil=True)) (λ(_cs_1) → _cs_1 + _cs_1)(x + x) The pass visits the tree top-down starting from the root node, e.g. an itir.Program. @@ -395,33 +410,33 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): def apply( cls, node: ProgramOrExpr, - is_local_view: bool | None = None, + within_stencil: bool | None = None, offset_provider: common.OffsetProvider | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: - assert is_local_view is None - is_local_view = False + assert within_stencil is None + within_stencil = False else: assert ( - is_local_view is not None - ), "The expression's context must be specified using `is_local_view`." + within_stencil is not None + ), "The expression's context must be specified using `within_stencil`." offset_provider = offset_provider or {} node = itir_type_inference.infer( node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program ) - return cls().visit(node, is_local_view=is_local_view) + return cls().visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): if cpm.is_call_to("as_fieldop", node): - assert not kwargs.get("is_local_view") - is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view") + assert not kwargs.get("within_stencil") + within_stencil = cpm.is_call_to("as_fieldop", node) or kwargs.get("within_stencil") - return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view})) + return super().generic_visit(node, **(kwargs | {"within_stencil": within_stencil})) def visit_FunCall(self, node: itir.FunCall, **kwargs): - is_local_view = kwargs["is_local_view"] + within_stencil = kwargs["within_stencil"] if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): return node @@ -431,7 +446,7 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # view, even though the syntactic context `node` is in field view. # note: what is extracted is sketched in the docstring above. keep it updated. if num_occurences > 1: - if is_local_view: + if within_stencil: return True # condition is only necessary since typing on lambdas is not preserved during # the transformation @@ -439,11 +454,13 @@ def predicate(subexpr: itir.Expr, num_occurences: int): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expressions, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) for stype in type_info.primitive_constituents(subexpr.type) - ): + ) and not _is_trivial_tuple_expr(subexpr): return True return False diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 51bbd91d83..da238733da 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -13,7 +13,12 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.transforms import ( + inline_center_deref_lift_vars, + inline_lambdas, + inline_lifts, + trace_shifts, +) from gt4py.next.iterator.type_system import ( inference as type_inference, type_specifications as it_ts, @@ -54,6 +59,14 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return expr +def _is_tuple_expr_of_literals(expr: itir.Expr): + if cpm.is_call_to(expr, "make_tuple"): + return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) + if cpm.is_call_to(expr, "tuple_get"): + return _is_tuple_expr_of_literals(expr.args[1]) + return isinstance(expr, itir.Literal) + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -153,11 +166,15 @@ def visit_FunCall(self, node: itir.FunCall): for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.extract_dtype(arg.type) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = isinstance(arg, itir.Literal) or ( + should_inline = _is_tuple_expr_of_literals(arg) or ( isinstance(arg, itir.FunCall) - and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) if should_inline: @@ -168,7 +185,7 @@ def visit_FunCall(self, node: itir.FunCall): type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) arg.type = type_ - elif isinstance(arg, itir.Literal): + elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: raise NotImplementedError() @@ -179,6 +196,7 @@ def visit_FunCall(self, node: itir.FunCall): new_args = _merge_arguments(new_args, extracted_args) else: + assert not isinstance(dtype, it_ts.ListType) new_param: str if isinstance( arg, itir.SymRef @@ -189,15 +207,19 @@ def visit_FunCall(self, node: itir.FunCall): new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - # simplify stencil directly to keep the tree small - new_stencil_body = inline_lambdas.InlineLambdas.apply( - new_stencil_body, opcount_preserving=True - ) - new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( *new_args.values() ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + type_inference.copy_type(from_=node, to=new_node) return new_node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 11d3fccec1..90f8a6cded 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -175,7 +175,10 @@ def _transform_stmt( def create_global_tmps( - program: itir.Program, offset_provider: common.OffsetProvider + program: itir.Program, + offset_provider: common.OffsetProvider, + *, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: """ Given an `itir.Program` create temporaries for intermediate values. @@ -186,7 +189,8 @@ def create_global_tmps( program = infer_domain.infer_program(program, offset_provider) program = type_inference.infer(program, offset_provider=offset_provider) - uids = eve_utils.UIDGenerator(prefix="__tmp") + if not uids: + uids = eve_utils.UIDGenerator(prefix="__tmp") declarations = program.declarations.copy() new_body = [] diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 2a85e6f2cf..6852b47a7a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,8 +10,9 @@ import itertools import typing -from typing import Callable, TypeAlias +from typing import Callable, Optional, TypeAlias +from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -28,6 +29,18 @@ ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAnnexDebugger(eve.NodeVisitor): + """ + Small utility class to debug missing domain attribute in annex. + """ + + def visit_Node(self, node: itir.Node): + if cpm.is_applied_as_fieldop(node): + if not hasattr(node.annex, "domain"): + breakpoint() # noqa: T100 + return self.generic_visit(node) + + def _split_dict_by_key(pred: Callable, d: dict): """ Split dictionary into two based on predicate. @@ -107,6 +120,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: domain_utils.SymbolicDomain, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> ACCESSED_DOMAINS: accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} @@ -114,7 +128,9 @@ def _extract_accessed_domains( for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ - domain_utils.SymbolicDomain.translate(target_domain, shift, offset_provider) + domain_utils.SymbolicDomain.translate( + target_domain, shift, offset_provider, symbolic_domain_sizes + ) for shift in shifts_list ] # `None` means field is never accessed @@ -125,10 +141,11 @@ def _extract_accessed_domains( return typing.cast(ACCESSED_DOMAINS, accessed_domains) -def infer_as_fieldop( +def _infer_as_fieldop( applied_fieldop: itir.FunCall, target_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") @@ -161,7 +178,7 @@ def infer_as_fieldop( input_ids.append(id_) inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s @@ -169,7 +186,7 @@ def infer_as_fieldop( transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider + in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes ) transformed_inputs.append(transformed_input) @@ -187,15 +204,16 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def infer_let( +def _infer_let( let_expr: itir.FunCall, input_domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider + let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes ) let_params = {param_sym.id for param_sym in let_expr.fun.params} @@ -212,6 +230,7 @@ def infer_let( None, ), offset_provider, + symbolic_domain_sizes, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -226,10 +245,11 @@ def infer_let( return transformed_call, accessed_domains_outer -def infer_make_tuple( +def _infer_make_tuple( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] @@ -245,17 +265,20 @@ def infer_make_tuple( # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain[i], offset_provider, symbolic_domain_sizes + ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) return result_expr, actual_domains -def infer_tuple_get( +def _infer_tuple_get( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "tuple_get") actual_domains: ACCESSED_DOMAINS = {} @@ -263,24 +286,29 @@ def infer_tuple_get( assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + ) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) return infered_args_expr, actual_domains -def infer_if( +def _infer_if( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] actual_domains: ACCESSED_DOMAINS = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr( + arg, domain, offset_provider, symbolic_domain_sizes + ) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -291,25 +319,26 @@ def _infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return infer_as_fieldop(expr, domain, offset_provider) + return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_let(expr): - return infer_let(expr, domain, offset_provider) + return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "make_tuple"): - return infer_make_tuple(expr, domain, offset_provider) + return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "tuple_get"): - return infer_tuple_get(expr, domain, offset_provider) + return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) elif cpm.is_call_to(expr, "if_"): - return infer_if(expr, domain, offset_provider) + return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) - or cpm.is_call_to(expr, "cast_") + or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} else: @@ -320,40 +349,79 @@ def infer_expr( expr: itir.Expr, domain: DOMAIN, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + """ + Infer the domain of all field subexpressions of `expr`. + + Given an expression `expr` and the domain it is accessed at, back-propagate the domain of all + (field-typed) subexpression. + + Arguments: + - expr: The expression to be inferred. + - domain: The domain `expr` is read at. + - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol + name that evaluates to the length of that axis. + + Returns: + A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) + having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to + domain they are accessed at. + """ # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider) + expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) expr.annex.domain = domain return expr, accessed_domains +def _infer_stmt( + stmt: itir.Stmt, + offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]], +): + if isinstance(stmt, itir.SetAt): + transformed_call, _unused_domain = infer_expr( + stmt.expr, + domain_utils.SymbolicDomain.from_expr(stmt.domain), + offset_provider, + symbolic_domain_sizes, + ) + return itir.SetAt( + expr=transformed_call, + domain=stmt.domain, + target=stmt.target, + ) + elif isinstance(stmt, itir.IfStmt): + return itir.IfStmt( + cond=stmt.cond, + true_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch + ], + false_branch=[ + _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch + ], + ) + raise ValueError(f"Unsupported stmt: {stmt}") + + def infer_program( program: itir.Program, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - transformed_set_ats: list[itir.SetAt] = [] + """ + Infer the domain of all field subexpressions inside a program. + + See :func:`infer_expr` for more details. + """ assert ( not program.function_definitions ), "Domain propagation does not support function definitions." - for set_at in program.body: - assert isinstance(set_at, itir.SetAt) - - transformed_call, _unused_domain = infer_expr( - set_at.expr, domain_utils.SymbolicDomain.from_expr(set_at.domain), offset_provider - ) - transformed_set_ats.append( - itir.SetAt( - expr=transformed_call, - domain=set_at.domain, - target=set_at.target, - ), - ) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=transformed_set_ats, + body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], ) diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index f899da73b1..33e36bfa4b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +# FIXME[#1582](tehrengruber): This transformation is not used anymore. Decide on its fate. from typing import Sequence, TypeGuard from gt4py import eve diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 399a7a3dc6..5ec9ec5d0b 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -111,6 +111,9 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) return new_expr @@ -120,10 +123,10 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """ Inline lambda calls by substituting every argument by its value. - Note: This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. + Note: This pass preserves, but doesn't use the `type` `recorded_shifts`, `domain` annex. """ - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") opcount_preserving: bool diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py new file mode 100644 index 0000000000..c6e2c38b90 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -0,0 +1,31 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +class InlineScalar(eve.NodeTranslator): + @classmethod + def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): + program = itir_inference.infer(program, offset_provider=offset_provider) + return cls().visit(program) + + def visit_Expr(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_let(node): + eligible_params = [isinstance(arg.type, ts.ScalarType) for arg in node.args] + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_params) + return node + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0c08bf2b9d..52a452155a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,28 +6,30 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import enum from typing import Callable, Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, infer_domain, inline_fundefs +from gt4py.next.iterator.transforms import ( + fencil_to_program, + fuse_as_fieldop, + global_tmps, + infer_domain, + inline_fundefs, + inline_lifts, +) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.inline_scalar import InlineScalar from gt4py.next.iterator.transforms.merge_let import MergeLet from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce +from gt4py.next.iterator.type_system.inference import infer class ITIRTransform(Protocol): @@ -36,45 +38,12 @@ def __call__( ) -> itir.Program: ... -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir - - # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward -# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. +# `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, - lift_mode=None, + extract_temporaries=False, offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, @@ -84,57 +53,52 @@ def apply_common_transforms( temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place + #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for + #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram().apply( - ir - ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR - else: - assert isinstance(ir, itir.Program) - # FIXME[#1582](havogt): note: currently the case when using the roundtrip backend - pass + ir = fencil_to_program.FencilToProgram.apply(ir) + assert isinstance(ir, itir.Program) - icdlv_uids = eve_utils.UIDGenerator() + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") + mergeasfop_uids = eve_utils.UIDGenerator() - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) ir = NormalizeShifts().visit(ir) + # note: this increases the size of the tree + # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) + ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = infer_domain.infer_program( + ir, # type: ignore[arg-type] # always an itir.Program + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + ) + for _ in range(10): inlined = ir - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # still a `itir.Program` + inlined = InlineLambdas.apply(inlined, opcount_preserving=True) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` - inlined, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + + # This pass is required to run after CollapseTuple as otherwise we can not inline + # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns + # a list. Such expressions must be inlined however because no backend supports such + # field operators right now. + inlined = fuse_as_fieldop.FuseAsFieldOp.apply( + inlined, uids=mergeasfop_uids, offset_provider=offset_provider ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) if inlined == ir: break @@ -142,48 +106,21 @@ def apply_common_transforms( else: raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - if lift_mode != LiftMode.FORCE_INLINE: - # FIXME[#1582](tehrengruber): implement new temporary pass here - raise NotImplementedError() - # ruff: noqa: ERA001 - # assert offset_provider is not None - # ir = CreateGlobalTmps().visit( - # ir, - # offset_provider=offset_provider, - # extraction_heuristics=temporary_extraction_heuristics, - # symbolic_sizes=symbolic_domain_sizes, - # ) - # - # for _ in range(10): - # inlined = InlineLifts().visit(ir) - # inlined = InlineLambdas.apply( - # inlined, opcount_preserving=True, force_inline_lift_args=True - # ) - # if inlined == ir: - # break - # ir = inlined - # else: - # raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - # - # # If after creating temporaries, the scan is not at the top, we inline. - # # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. - # # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - # ir = _inline_into_scan(ir) + # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = MergeLet().visit(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) + + if extract_temporaries: + ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` - ir, - ignore_tuple_size=True, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) + ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -198,18 +135,13 @@ def apply_common_transforms( ir = unrolled # type: ignore[assignment] # still a `itir.Program` ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + # this is required as nested neighbor reductions can contain lifts, e.g., + # `neighbors(V2Eₒ, ↑f(...))` + ir = inline_lifts.InlineLifts().visit(ir) ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) - ir = MergeLet().visit(ir) - ir = InlineLambdas.apply( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py new file mode 100644 index 0000000000..792bb421f1 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -0,0 +1,175 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR +import enum +from typing import Callable, Optional + +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination +from gt4py.next.iterator.transforms.eta_reduction import EtaReduction +from gt4py.next.iterator.transforms.fuse_maps import FuseMaps +from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas +from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.merge_let import MergeLet +from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts +from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref +from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce + + +@enum.unique +class LiftMode(enum.Enum): + FORCE_INLINE = enum.auto() + USE_TEMPORARIES = enum.auto() + + +def _inline_lifts(ir, lift_mode): + if lift_mode == LiftMode.FORCE_INLINE: + return InlineLifts().visit(ir) + elif lift_mode == LiftMode.USE_TEMPORARIES: + return InlineLifts( + flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. + ).visit(ir) + else: + raise ValueError() + + return ir + + +def _inline_into_scan(ir, *, max_iter=10): + for _ in range(10): + # in case there are multiple levels of lambdas around the scan we have to do multiple iterations + inlined = InlineIntoScan().visit(ir) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") + return ir + + +def apply_common_transforms( + ir: itir.Node, + *, + lift_mode=None, + offset_provider=None, + unroll_reduce=False, + common_subexpression_elimination=True, + force_inline_lambda_args=False, + unconditionally_collapse_tuples=False, + temporary_extraction_heuristics: Optional[ + Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] + ] = None, + symbolic_domain_sizes: Optional[dict[str, str]] = None, +) -> itir.Program: + assert isinstance(ir, itir.FencilDefinition) + ir = fencil_to_program.FencilToProgram().apply(ir) + icdlv_uids = eve_utils.UIDGenerator() + + if lift_mode is None: + lift_mode = LiftMode.FORCE_INLINE + assert isinstance(lift_mode, LiftMode) + ir = MergeLet().visit(ir) + ir = inline_fundefs.InlineFundefs().visit(ir) + + ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = PropagateDeref.apply(ir) + ir = NormalizeShifts().visit(ir) + + for _ in range(10): + inlined = ir + + inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil + inlined = _inline_lifts(inlined, lift_mode) + + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), + # If trivial lifts are not inlined we might create temporaries for constants. In all + # other cases we want it anyway. + force_inline_trivial_lift_args=True, + ) + inlined = ConstantFolding.apply(inlined) + # This pass is required to be in the loop such that when an `if_` call with tuple arguments + # is constant-folded the surrounding tuple_get calls can be removed. + inlined = CollapseTuple.apply( + inlined, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) + + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + + if lift_mode != LiftMode.FORCE_INLINE: + raise NotImplementedError() + + # Since `CollapseTuple` relies on the type inference which does not support returning tuples + # larger than the number of closure outputs as given by the unconditional collapse, we can + # only run the unconditional version here instead of in the loop above. + if unconditionally_collapse_tuples: + ir = CollapseTuple.apply( + ir, + ignore_tuple_size=True, + offset_provider=offset_provider, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + + if lift_mode == LiftMode.FORCE_INLINE: + ir = _inline_into_scan(ir) + + ir = NormalizeShifts().visit(ir) + + ir = FuseMaps().visit(ir) + ir = CollapseListGet().visit(ir) + + if unroll_reduce: + for _ in range(10): + unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + if unrolled == ir: + break + ir = unrolled + ir = CollapseListGet().visit(ir) + ir = NormalizeShifts().visit(ir) + ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + ir = NormalizeShifts().visit(ir) + else: + raise RuntimeError("Reduction unrolling failed.") + + ir = EtaReduction().visit(ir) + ir = ScanEtaReduction().visit(ir) + + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = MergeLet().visit(ir) + + ir = InlineLambdas.apply( + ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args + ) + + assert isinstance(ir, itir.Program) + return ir diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 02180a3699..08d896121d 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -13,8 +13,8 @@ class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): return symbol_map.get(str(node.id), node) @@ -32,8 +32,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] class RenameSymbols(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 700b8571a5..ec9c3efb2b 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -30,7 +30,14 @@ def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunC def _get_neighbors_args(reduce_args: Iterable[itir.Expr]) -> Iterator[itir.FunCall]: - return filter(_is_neighbors_or_lifted_and_neighbors, reduce_args) + flat_reduce_args: list[itir.Expr] = [] + for arg in reduce_args: + if cpm.is_call_to(arg, "if_"): + flat_reduce_args.extend(_get_neighbors_args(arg.args[1:3])) + else: + flat_reduce_args.append(arg) + + return filter(_is_neighbors_or_lifted_and_neighbors, flat_reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[itir.FunCall]]: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6579107197..43c4465576 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -299,7 +299,11 @@ def as_fieldop( @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: - if any(isinstance(f, ts.DeferredType) for f in fields): + if any( + isinstance(el, ts.DeferredType) + for f in fields + for el in type_info.primitive_constituents(f) + ): return ts.DeferredType(constraint=None) stencil_return = stencil( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 20a1a0cf76..85a100a88d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import ClassVar, Optional, Union +from typing import Callable, ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait @@ -96,25 +96,23 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool: +def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "tuple_get" and len(expr.args) == 2 - and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1]) + and _is_tuple_expr_of(pred, expr.args[1]) ): return True if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "make_tuple" - and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args) + and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) ): return True - if isinstance(expr, (SymRef, Literal)): - return True - return False + return pred(expr) class SidComposite(Expr): @@ -126,14 +124,32 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_ref_literal_or_tuple_expr_of_ref(el) + or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) for el in value ): raise ValueError( - "Only 'SymRef', tuple expr of 'SymRef', 'SidFromScalar', or 'SidComposite' allowed." + "Only 'SymRef', 'Literal', tuple expr thereof, 'SidFromScalar', or 'SidComposite' allowed." ) +def _might_be_scalar_expr(expr: Expr) -> bool: + if isinstance(expr, BinaryExpr): + return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + if isinstance(expr, UnaryExpr): + return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + if ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id in ARITHMETIC_BUILTINS + ): + return all(_might_be_scalar_expr(arg) for arg in expr.args) + if isinstance(expr, CastExpr): + return _might_be_scalar_expr(expr.obj_expr) + if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + return True + return False + + class SidFromScalar(Expr): arg: Expr @@ -141,8 +157,10 @@ class SidFromScalar(Expr): def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr ) -> None: - if not _is_ref_literal_or_tuple_expr_of_ref(value): - raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.") + if not _might_be_scalar_expr(value): + raise ValueError( + "Only 'SymRef', 'Literal', arithmetic op or tuple expr thereof allowed." + ) class Stmt(Node): @@ -155,6 +173,24 @@ class StencilExecution(Stmt): output: Union[SymRef, SidComposite] inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] + @datamodels.validator("inputs") + def _arg_validator( + self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] + ) -> None: + for inp in inputs: + if not _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) + or ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id == "index" + ), + inp, + ): + raise ValueError( + "Only 'SymRef', 'SidComposite', 'SidFromScalar', 'index' call or tuple expr thereof allowed." + ) + class Scan(Node): function: SymRef diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 66d74d53cc..ce459f7970 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -20,7 +20,7 @@ from gt4py.next import common from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -51,7 +51,6 @@ class GTFNTranslationStep( # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 enable_itir_transforms: bool = True use_imperative_backend: bool = False - lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -168,14 +167,9 @@ def _preprocess_program( program: itir.FencilDefinition | itir.Program, offset_provider: dict[str, common.Connectivity | common.Dimension], ) -> itir.Program: - if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: - return fencil_to_program.FencilToProgram().apply( - program - ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR - apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=self.lift_mode, + extract_temporaries=True, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, @@ -203,7 +197,12 @@ def generate_stencil_source( offset_provider: dict[str, common.Connectivity | common.Dimension], column_axis: Optional[common.Dimension], ) -> str: - new_program = self._preprocess_program(program, offset_provider) + if self.enable_itir_transforms: + new_program = self._preprocess_program(program, offset_provider) + else: + assert isinstance(program, itir.Program) + new_program = program + gtfn_ir = GTFN_lowering.apply( new_program, offset_provider=offset_provider, column_axis=column_axis ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index fb2645208c..bc2bd645e8 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,7 +15,7 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, @@ -67,6 +67,27 @@ def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: _horizontal_dimension = "gtfn::unstructured::dim::horizontal" +def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "tuple_get" + and len(expr.args) == 2 + and _is_tuple_of_ref_or_literal(expr.args[1]) + ): + return True + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "make_tuple" + and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) + ): + return True + if isinstance(expr, (itir.SymRef, itir.Literal)): + return True + return False + + def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -587,6 +608,9 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: + if _is_tuple_of_ref_or_literal(node.expr): + node.expr = im.as_fieldop("deref", node.domain)(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index 7b722a7c1a..0a8253595e 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -51,9 +51,7 @@ class ToLispLike(TemplatedGenerator): @classmethod def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms( - root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] - ) + transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) generated_code = super().apply(transformed, **kwargs) try: from yasi import indent_code diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 9a45b6a29a..95186e0b5d 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -9,7 +9,6 @@ import factory from gt4py.next import backend -from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory @@ -33,7 +32,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" ) - transforms = backend.DEFAULT_TRANSFORMS + transforms = backend.LEGACY_TRANSFORMS run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) @@ -59,13 +58,7 @@ class Params: lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" ) - transforms = backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(), - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() - ), - ) + transforms = backend.DEFAULT_TRANSFORMS gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 6383d4bb44..fc2772027e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -25,7 +25,10 @@ from gt4py.next.ffront import decorator from gt4py.next.iterator import transforms as itir_transforms from gt4py.next.iterator.ir import SymRef -from gt4py.next.iterator.transforms import program_to_fencil +from gt4py.next.iterator.transforms import ( + pass_manager_legacy as legacy_itir_transforms, + program_to_fencil, +) from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.type_system import type_specifications as ts @@ -36,14 +39,14 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], - lift_mode: itir_transforms.LiftMode, + lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + node = legacy_itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -73,7 +76,7 @@ def build_sdfg_from_itir( auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, + lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] @@ -234,7 +237,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: } sdfg.offset_providers_per_input_field = {} - itir_tmp = itir_transforms.apply_common_transforms( + itir_tmp = legacy_itir_transforms.apply_common_transforms( self.itir, offset_provider=offset_provider ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 7a442e3819..740f1979cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import program_to_fencil from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -36,7 +36,6 @@ class DaCeTranslator( step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): auto_optimize: bool = False - lift_mode: LiftMode = LiftMode.FORCE_INLINE device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None temporary_extraction_heuristics: Optional[ @@ -69,7 +68,6 @@ def generate_sdfg( auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, - lift_mode=self.lift_mode, symbolic_domain_sizes=self.symbolic_domain_sizes, temporary_extraction_heuristics=self.temporary_extraction_heuristics, load_sdfg_from_file=False, @@ -82,7 +80,9 @@ def __call__( ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" program: itir.FencilDefinition | itir.Program = inp.data - assert isinstance(program, itir.FencilDefinition) + + if isinstance(program, itir.Program): + program = program_to_fencil.program_to_fencil(program) sdfg = self.generate_sdfg( program, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 4a788bf40c..965c6417b2 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config from gt4py.next.common import Connectivity, Dimension -from gt4py.next.iterator import ir as itir, transforms +from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -166,19 +166,19 @@ class Params: cached_translation = factory.Trait( translation=factory.LazyAttribute( lambda o: workflow.CachedStep( - o.translation_, + o.bare_translation, hash_function=fingerprint_compilable_program, cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), ) - translation_ = factory.SubFactory( + bare_translation = factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), ) - translation = factory.LazyAttribute(lambda o: o.translation_) + translation = factory.LazyAttribute(lambda o: o.bare_translation) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source @@ -213,12 +213,6 @@ class Params: ), name_cached="_cached", ) - use_temporaries = factory.Trait( - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - # otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, # noqa: ERA001 - name_temps="_with_temporaries", - ) device_type = core_defs.DeviceType.CPU hash_function = compilation_hash otf_workflow = factory.SubFactory( @@ -242,8 +236,10 @@ class Params: run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) -run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) - run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) + +run_gtfn_no_transforms = GTFNBackendFactory( + otf_workflow__bare_translation__enable_itir_transforms=False +) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 57785ceb33..4d518d7fcc 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -103,6 +103,7 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. + extract_temporaries: Extract intermediate field values into temporaries. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. @@ -209,7 +210,7 @@ def decorated_fencil( ) -> None: if out is not None: args = (*args, out) - if not column_axis: + if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? column_axis = inp.args.column_axis fencil( *args, @@ -222,11 +223,13 @@ def decorated_fencil( return decorated_fencil +# TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", executor=Roundtrip( transforms=functools.partial( - itir_transforms.apply_common_transforms, lift_mode=itir_transforms.LiftMode.FORCE_INLINE + itir_transforms.apply_common_transforms, + extract_temporaries=False, ) ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), @@ -237,12 +240,18 @@ def decorated_fencil( executor=Roundtrip( transforms=functools.partial( itir_transforms.apply_common_transforms, - lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES, + extract_temporaries=True, ) ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) +no_transforms = next_backend.Backend( + name="roundtrip", + executor=Roundtrip(transforms=lambda o, *, offset_provider: o), + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms=next_backend.DEFAULT_TRANSFORMS, +) gtir = next_backend.Backend( @@ -257,3 +266,4 @@ def decorated_fencil( ), ), ) +foast_to_gtir_step = foast_to_gtir.adapted_foast_to_gtir_factory(cached=True) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5bda9a6f2e..66f8937dc5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -459,7 +459,9 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ if isinstance(symbol_type, ts.DeferredType) and ( - symbol_type.constraint is None or issubclass(type_class(to_type), symbol_type.constraint) + symbol_type.constraint is None + or (isinstance(to_type, ts.DeferredType) and to_type.constraint is None) + or issubclass(type_class(to_type), symbol_type.constraint) ): return True elif is_concrete(symbol_type): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1bcc3554a7..c86ba88ead 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -43,11 +43,10 @@ def short_id(self, num_components: int = 2) -> str: class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn" GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative" - GTFN_CPU_WITH_TEMPORARIES = ( - "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" - ) + GTFN_CPU_NO_TRANSFORMS = "gt4py.next.program_processors.runners.gtfn.run_gtfn_no_transforms" GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_NO_TRANSFORMS = "gt4py.next.program_processors.runners.roundtrip.no_transforms" GTIR_EMBEDDED = "gt4py.next.program_processors.runners.roundtrip.gtir" ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -102,6 +101,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" @@ -130,13 +130,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] +# Markers to skip because of missing features in the domain inference +DOMAIN_INFERENCE_SKIP_LIST = [ + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), +] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -148,8 +153,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] -GTIR_DACE_SKIP_TEST_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), +GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), @@ -164,14 +168,22 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] -GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 - (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 - (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), +ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] +GTFN_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 + (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) @@ -192,20 +204,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) - ], - ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTIR_EMBEDDED: [ - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + ProgramFormatterId.GTFN_CPP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], - ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + ProgramFormatterId.LISP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST, + ProgramBackendId.ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.DOUBLE_ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: ROUNDTRIP_SKIP_LIST + + [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 0ed3365969..c64efb27d2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -49,7 +49,6 @@ def __gt_allocator__( next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, pytest.param( next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu ), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index f26424bf0e..47419c278b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,9 +30,9 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, (itir.FencilDefinition, itir.Program)) + assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) assert isinstance( - testee.with_backend(cartesian_case.backend).itir, (itir.FencilDefinition, itir.Program) + testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 27f94960dc..f10f195d3a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -301,6 +301,21 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) +@pytest.mark.uses_tuple_args +def test_double_use_scalar(cartesian_case): + # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't + # work for this case. + @gtx.field_operator + def testee(a: np.int32, b: np.int32, c: cases.IField) -> cases.IField: + tmp = a * b + tmp2 = tmp * tmp + # important part here is that we use the intermediate twice so that it is + # not inlined + return tmp2 * tmp2 * c + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * c) + + @pytest.mark.uses_scalar_in_domain_and_fo def test_scalar_in_domain_spec_and_fo_call(cartesian_case): @gtx.field_operator @@ -688,9 +703,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( state: tuple[float, float], a: float, b: float, c: float, d: float @@ -789,9 +801,6 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 @@ -814,9 +823,6 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.uses_scan_without_field_args @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] expected = np.arange(1, 1 + k_size, 1, dtype=int32) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 0efb599f9e..7ff7edf226 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -56,6 +56,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField cases.verify(cartesian_case, simple_if, a, b, condition, out=out, ref=a if condition else b) +# TODO(tehrengruber): test with fields on different domains @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) @pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 0305a5841a..11e28de9e1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -5,14 +5,15 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import platform import pytest from numpy import int32, int64 from gt4py import next as gtx from gt4py.next import backend, common -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from gt4py.next.iterator.transforms import apply_common_transforms +from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -34,8 +35,8 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, - executor=run_gtfn_with_temporaries.executor.replace( - translation=run_gtfn_with_temporaries.executor.translation.replace( + executor=run_gtfn.executor.replace( + translation=run_gtfn.executor.translation.replace( symbolic_domain_sizes={ "Cell": "num_cells", "Edge": "num_edges", @@ -43,7 +44,7 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): } ) ), - allocator=run_gtfn_with_temporaries.allocator, + allocator=run_gtfn.allocator, ) @@ -64,8 +65,14 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") + if platform.machine() == "x86_64": + pytest.xfail( + reason="The C++ code generated in this test contains unicode characters " + "(coming from the ssa pass) which is not supported by gcc 9 used" + "in the CI. Bumping the container version sadly did not work for" + "unrelated and unclear reasons. Since the issue is not present" + "on Alps we just skip the test for now before investing more time." + ) unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, @@ -100,12 +107,9 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") - itir_with_tmp = apply_common_transforms( testee.itir, - lift_mode=LiftMode.USE_TEMPORARIES, + extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c2f72e4ca7..3fc4ed9945 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -45,8 +45,9 @@ plus, shift, xor_, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -87,7 +88,9 @@ def dispatch(arg0): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0), domain, out) elif len(inps) == 2: @@ -102,7 +105,9 @@ def dispatch(arg0, arg1): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1), domain, out) elif len(inps) == 3: @@ -117,7 +122,9 @@ def dispatch(arg0, arg1, arg2): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, arg2, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1, arg2), domain, out) else: raise AssertionError("Add overload.") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index a86959d075..e462aa07eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,7 +18,9 @@ @pytest.mark.uses_index_fields +@pytest.mark.uses_scan_in_stencil def test_scan_in_stencil(program_processor): + # FIXME[#1582](tehrengruber): Remove test after scan is reworked. program_processor, validate = program_processor isize = 1 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 505879a506..19664f2dab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -227,14 +227,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail( - "Needs implementation of scan projector. Breaks in type inference as executed" - "again after CollapseTuple." - ) if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") @@ -254,12 +246,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - cases.run( test_setup.case, solve_nonhydro_stencil_52_like, @@ -276,11 +262,6 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14271efb27..3ce9d6b470 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -78,11 +78,6 @@ def naive_lap(inp): def test_anton_toy(stencil, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn_with_temporaries.executor, - ]: - pytest.xfail("TODO: issue with temporaries that crashes the application") - if stencil is lap: pytest.xfail( "Type inference does not support calling lambdas with offset arguments of changing type." diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2b858f3025..f8e9f22eff 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next import field_utils from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -170,23 +170,14 @@ def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_fu @fundef -def sum_scanpass(state, inp): +def ksum(state, inp): return state + deref(inp) -@fundef -def ksum(inp): - return scan(sum_scanpass, True, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_fencil(i_size, k_start, k_end, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - ksum, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(ksum, True, 0.0), domain)(inp), domain, out) @pytest.mark.parametrize( @@ -214,19 +205,10 @@ def test_ksum_scan(program_processor, kstart, reference): assert np.allclose(reference, out.asnumpy()) -@fundef -def ksum_back(inp): - return scan(sum_scanpass, False, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_back_fencil(i_size, k_size, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), - ksum_back, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)) + set_at(as_fieldop(scan(ksum, False, 0.0), domain)(inp), domain, out) def test_ksum_back_scan(program_processor): @@ -252,23 +234,14 @@ def test_ksum_back_scan(program_processor): @fundef -def doublesum_scanpass(state, inp0, inp1): +def kdoublesum(state, inp0, inp1): return make_tuple(tuple_get(0, state) + deref(inp0), tuple_get(1, state) + deref(inp1)) -@fundef -def kdoublesum(inp0, inp1): - return scan(doublesum_scanpass, True, make_tuple(0.0, 0))(inp0, inp1) - - @fendef(column_axis=KDim) def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - kdoublesum, - out, - [inp0, inp1], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(kdoublesum, True, make_tuple(0.0, 0)), domain)(inp0, inp1), domain, out) @pytest.mark.parametrize( @@ -325,7 +298,8 @@ def sum_shifted(inp0, inp1): @fendef(column_axis=KDim) def sum_shifted_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 1, k_size)), sum_shifted, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 1, k_size)) + set_at(as_fieldop(sum_shifted, domain)(inp0, inp1), domain, out) def test_different_vertical_sizes(program_processor): @@ -352,7 +326,8 @@ def sum(inp0, inp1): @fendef(column_axis=KDim) def sum_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 0, k_size)), sum, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 0, k_size)) + set_at(as_fieldop(sum, domain)(inp0, inp1), domain, out) @pytest.mark.uses_origin diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 156bc1c37f..3db4497910 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -30,7 +30,6 @@ unstructured_domain, ) from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.iterator.transforms.pass_manager import LiftMode from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index 2dde7d7653..c38a29bc61 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -28,7 +28,7 @@ from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn -from next_tests.unit_tests.conftest import program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor_no_transforms, run_processor i = offset("i") @@ -43,8 +43,8 @@ def multiply(alpha, inp): @pytest.mark.uses_ir_if_stmts @pytest.mark.parametrize("cond", [True, False]) -def test_if_stmt(program_processor, cond): - program_processor, validate = program_processor +def test_if_stmt(program_processor_no_transforms, cond): + program_processor, validate = program_processor_no_transforms size = 10 @fendef(offset_provider={"i": IDim}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a89f250571..30ceaf9376 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -10,9 +10,9 @@ import pytest import gt4py.next as gtx +from gt4py.next import backend from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.runtime import set_at, fendef, fundef from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters from gt4py.next.program_processors.runners import gtfn @@ -42,22 +42,17 @@ def tridiag_backward2(x_kp1, cp, dp): @fundef -def solve_tridiag(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward, False, 0.0)(cpdp) - - -def tuple_get_it(i, x): - def stencil(x): - return tuple_get(i, deref(x)) - - return lift(stencil)(x) +def solve_tridiag(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward, False, 0.0), domain)(cpdp) @fundef -def solve_tridiag2(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward2, False, 0.0)(tuple_get_it(0, cpdp), tuple_get_it(1, cpdp)) +def solve_tridiag2(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward2, False, 0.0), domain)( + tuple_get(0, cpdp), tuple_get(1, cpdp) + ) @pytest.fixture @@ -80,40 +75,27 @@ def tridiag_reference(): @fendef def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag(domain, a, b, c, d), domain, x) @fendef def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag2, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag2(domain, a, b, c, d), domain, x) @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -@pytest.mark.uses_lift_expressions def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn_formatters.format_cpp, - ]: - pytest.skip("gtfn does only support lifted scans when using temporaries") - if program_processor == gtfn.run_gtfn_with_temporaries: - pytest.xfail("tuple_get on columns not supported.") + + if isinstance(program_processor, backend.Backend) and "dace" in program_processor.name: + pytest.xfail("Dace ITIR backend doesn't support the IR format used in this test.") + a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = gtx.as_field.partial([IDim, JDim, KDim]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fb1d4c152..6fdc6a77a1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -383,7 +383,6 @@ def test_shift_sparse_input_field2(program_processor): if program_processor in [ gtfn.run_gtfn, gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8a4aa50730..ca66b45d6d 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -25,7 +25,30 @@ ProgramProcessor: TypeAlias = backend.Backend | program_formatter.ProgramFormatter -@pytest.fixture( +def _program_processor(request) -> tuple[ProgramProcessor, bool]: + """ + Fixture creating program processors on-demand for tests. + + Notes: + Check ADR 15 for details on the test-exclusion matrices. + """ + processor_id, is_backend = request.param + if processor_id is None: + return None, is_backend + + processor = processor_id.load() + + for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( + processor_id, [] + ): + if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=processor_id)) + + return processor, is_backend + + +program_processor = pytest.fixture( + _program_processor, params=[ (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), @@ -33,7 +56,6 @@ (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), @@ -50,26 +72,16 @@ ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) -def program_processor(request) -> tuple[ProgramProcessor, bool]: - """ - Fixture creating program processors on-demand for tests. - - Notes: - Check ADR 15 for details on the test-exclusion matrices. - """ - processor_id, is_backend = request.param - if processor_id is None: - return None, is_backend - - processor = processor_id.load() - - for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( - processor_id, [] - ): - if marker == next_tests.definitions.ALL or request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=processor_id)) - return processor, is_backend +program_processor_no_transforms = pytest.fixture( + _program_processor, + params=[ + (None, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_NO_TRANSFORMS, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_NO_TRANSFORMS, True), + ], + ids=lambda p: p[0].short_id() if p[0] is not None else "None", +) def run_processor( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 720076c8c2..28090ff1e2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,6 +20,7 @@ def test_simple_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) expected = tuple_of_size_2 @@ -37,6 +38,7 @@ def test_nested_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == tup_of_size2_from_lambda @@ -52,6 +54,7 @@ def test_different_tuples_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -65,6 +68,7 @@ def test_incompatible_order_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -76,6 +80,7 @@ def test_incompatible_size_make_tuple_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -87,6 +92,7 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == im.make_tuple("first", "second") @@ -99,6 +105,7 @@ def test_simple_tuple_get_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -111,6 +118,7 @@ def test_propagate_tuple_get(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -128,6 +136,7 @@ def test_letify_make_tuple_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -141,6 +150,7 @@ def test_letify_make_tuple_with_trivial_elements(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -154,6 +164,7 @@ def test_inline_trivial_make_tuple(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -172,6 +183,7 @@ def test_propagate_to_if_on_tuples(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -189,6 +201,7 @@ def test_propagate_to_if_on_tuples_with_let(): flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -201,6 +214,7 @@ def test_propagate_nested_lift(): remove_letified_make_tuple_elements=False, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -211,7 +225,10 @@ def test_if_on_tuples_with_let(): )(im.tuple_get(0, "val")) expected = im.if_("pred", 1, 3) actual = CollapseTuple.apply( - testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True + testee, + remove_letified_make_tuple_elements=False, + allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -220,5 +237,5 @@ def test_tuple_get_on_untyped_ref(): # test pass gracefully handles untyped nodes. testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) - actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True) + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 3204b49371..e04856b75f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -37,7 +37,7 @@ def test_trivial(): ), args=[common], ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -45,7 +45,7 @@ def test_lambda_capture(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common]) expected = testee - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -53,7 +53,7 @@ def test_lambda_no_capture(): common = im.plus("x", "y") testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y")) expected = im.let("_cs_1", common)("_cs_1") - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -65,7 +65,7 @@ def common_expr(): testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr()) # (λ(_cs_1) → _cs_1 + _cs_1)(x + y) expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -79,7 +79,7 @@ def common_expr(): expected = im.lambda_("x")( im.let("_cs_1", common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -93,7 +93,7 @@ def common_expr(): ) # (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1) expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3))) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -109,7 +109,7 @@ def common_expr(): expected = im.let("_cs_1", common_expr())( im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2")) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -133,7 +133,7 @@ def common_expr(): ) ) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -157,7 +157,7 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -178,7 +178,7 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) assert actual == expected @@ -268,7 +268,7 @@ def test_no_extraction_outside_asfieldop(): identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type)) ) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == testee @@ -289,5 +289,5 @@ def test_field_extraction_outside_asfieldop(): # ) expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 50756f40e7..141091b450 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -84,9 +84,13 @@ def run_test_expr( domain: itir.FunCall, expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider, + symbolic_domain_sizes, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -1021,3 +1025,22 @@ def test_scan(offset_provider): {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, offset_provider, ) + + +def test_symbolic_domain_sizes(unstructured_offset_provider): + stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) + domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + symbolic_domain_sizes = {"Vertex": "num_vertices"} + + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr( + testee, + expected, + domain, + {"in_field1": {Vertex: (0, im.ref("num_vertices"))}}, + unstructured_offset_provider, + symbolic_domain_sizes, + ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index da2c16336e..b5b9a62009 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -45,6 +45,31 @@ def test_trivial_literal(): assert actual == expected +def test_tuple_arg(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("t")(im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t"))), d)( + im.make_tuple(1, 2) + ), + 3, + ) + expected = im.as_fieldop( + im.lambda_()( + im.plus( + im.let("t", im.make_tuple(1, 2))( + im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t")) + ), + 3, + ) + ), + d, + )() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_symref_used_twice(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 09ed204a91..28bd88b853 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -26,93 +26,35 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ) - ], - ) + return im.call(im.call("reduce")("foo", 0.0))(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.SymRef(id="x"), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], - ) + return im.call(im.call("reduce")("foo", 0.0))("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim2"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", "x"), im.neighbors("Dim2", "y") ) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ - ir.OffsetLiteral(value="Dim"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="IrrelevantDim"), - ir.OffsetLiteral(value="0"), - ], - ), - args=[ir.SymRef(id="x")], - ), - ], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], + return im.call(im.call("reduce")("foo", 0.0))( + im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) -# TODO add a test with lift +@pytest.fixture +def reduction_if(): + UIDs.reset_sequence() + return im.call(im.call("reduce")("foo", 0.0))(im.if_(True, im.neighbors("Dim", "x"), "y")) @pytest.mark.parametrize( @@ -121,6 +63,7 @@ def reduction_with_irrelevant_full_shift(): "basic_reduction", "reduction_with_irrelevant_full_shift", "reduction_with_shift_on_second_arg", + "reduction_if", ], ) def test_get_partial_offsets(reduction, request): @@ -178,6 +121,14 @@ def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, assert actual == expected +def test_reduction_with_if(reduction_if): + expected = _expected(reduction_if, "Dim", 2, False) + + offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + assert actual == expected + + def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) From aeff1e37bb483faebc280776e18b83287aacbe49 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 15 Nov 2024 15:07:22 +0100 Subject: [PATCH 043/178] refactor[catesian]: Type hints and code redability improvements (#1724) ## Description This PR is split off the work for the new GT4Py - DaCe bridge, which should allow to expose control flow statements (`if` and `while`) to DaCe to better use DaCe's analytics capabilities. This PR is concerned with adding type hints and generally improving code readability. Main parts are - `daceir_builder.py`: early returns and renamed variable - `sdfg_builder.py`: type hints and early returns - `tasklet_codegen.py`: type hints and early returns `TaskletCodegen` was given `sdfg_ctx`, which wasn't used. That parameter was thus removed. Parent issue: https://github.com/GEOS-ESM/NDSL/issues/53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Assumed to be covered by existing tests. - [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <> --- src/gt4py/cartesian/gtc/common.py | 2 +- src/gt4py/cartesian/gtc/dace/daceir.py | 2 +- .../gtc/dace/expansion/daceir_builder.py | 84 +++++++++---------- .../gtc/dace/expansion/sdfg_builder.py | 36 ++++---- .../gtc/dace/expansion/tasklet_codegen.py | 64 +++++++------- src/gt4py/cartesian/gtc/dace/utils.py | 5 +- src/gt4py/eve/trees.py | 2 +- src/gt4py/eve/visitors.py | 4 +- 8 files changed, 98 insertions(+), 101 deletions(-) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index bfe434e7f3..dcb01db7ca 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -311,7 +311,7 @@ class CartesianOffset(eve.Node): k: int @classmethod - def zero(cls) -> "CartesianOffset": + def zero(cls) -> CartesianOffset: return cls(i=0, j=0, k=0) def to_dict(self) -> Dict[str, int]: diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 0ecb02b50f..78451c30f5 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -730,7 +730,7 @@ class Literal(common.Literal, Expr): class ScalarAccess(common.ScalarAccess, Expr): - name: eve.Coerced[eve.SymbolRef] + pass class VariableKOffset(common.VariableKOffset[Expr]): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index a8a3a3cb54..5f2007871e 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -74,8 +74,8 @@ def _get_tasklet_inout_memlets( *, get_outputs: bool, global_ctx: DaCeIRBuilder.GlobalContext, - **kwargs, -): + **kwargs: Any, +) -> List[dcir.Memlet]: access_infos = compute_dcir_access_infos( node, block_extents=global_ctx.library_node.get_extents, @@ -85,7 +85,7 @@ def _get_tasklet_inout_memlets( **kwargs, ) - res = list() + memlets: List[dcir.Memlet] = [] for name, offset, tasklet_symbol in _access_iter(node, get_outputs=get_outputs): access_info = access_infos[name] if not access_info.variable_offset_axes: @@ -95,26 +95,27 @@ def _get_tasklet_inout_memlets( axis, extent=(offset_dict[axis.lower()], offset_dict[axis.lower()]) ) - memlet = dcir.Memlet( - field=name, - connector=tasklet_symbol, - access_info=access_info, - is_read=not get_outputs, - is_write=get_outputs, + memlets.append( + dcir.Memlet( + field=name, + connector=tasklet_symbol, + access_info=access_info, + is_read=not get_outputs, + is_write=get_outputs, + ) ) - res.append(memlet) - return res + return memlets -def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval): - def all_statements_in_region(scope_nodes): +def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval: Any) -> bool: + def all_statements_in_region(scope_nodes: List[eve.Node]) -> bool: return all( isinstance(stmt, dcir.HorizontalRestriction) for tasklet in eve.walk_values(scope_nodes).if_isinstance(dcir.Tasklet) for stmt in tasklet.stmts ) - def all_regions_same(scope_nodes): + def all_regions_same(scope_nodes: List[eve.Node]) -> bool: return ( len( set( @@ -179,11 +180,11 @@ def _get_dcir_decl( oir_decl: oir.Decl = self.library_node.declarations[field] assert isinstance(oir_decl, oir.FieldDecl) dace_array = self.arrays[field] - for s in dace_array.strides: - for sym in dace.symbolic.symlist(s).values(): - symbol_collector.add_symbol(str(sym)) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) + for stride in dace_array.strides: + for symbol in dace.symbolic.symlist(stride).values(): + symbol_collector.add_symbol(str(symbol)) + for symbol in access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(symbol) return dcir.FieldDecl( name=field, @@ -236,11 +237,7 @@ def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.Iteration if not isinstance(item, (Map, Loop)): raise ValueError - if isinstance(item, Map): - iterations = item.iterations - else: - iterations = [item] - + iterations = item.iterations if isinstance(item, Map) else [item] grid_subset = self.grid_subset for it in iterations: axis = it.axis @@ -267,13 +264,13 @@ def pop(self) -> DaCeIRBuilder.IterationContext: class SymbolCollector: symbol_decls: Dict[str, dcir.SymbolDecl] = dataclasses.field(default_factory=dict) - def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32): + def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32) -> None: if name not in self.symbol_decls: self.symbol_decls[name] = dcir.SymbolDecl(name=name, dtype=dtype) else: assert self.symbol_decls[name].dtype == dtype - def remove_symbol(self, name: eve.SymbolRef): + def remove_symbol(self, name: eve.SymbolRef) -> None: if name in self.symbol_decls: del self.symbol_decls[name] @@ -304,11 +301,14 @@ def visit_HorizontalRestriction( symbol_collector.add_symbol(axis.iteration_symbol()) if bound.level == common.LevelMarker.END: symbol_collector.add_symbol(axis.domain_symbol()) + return dcir.HorizontalRestriction( mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs) ) - def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs): + def visit_VariableKOffset( + self, node: oir.VariableKOffset, **kwargs: Any + ) -> dcir.VariableKOffset: return dcir.VariableKOffset(k=self.visit(node.k, **kwargs)) def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl: @@ -419,7 +419,7 @@ def visit_HorizontalExecution( symbol_collector: DaCeIRBuilder.SymbolCollector, loop_order, k_interval, - **kwargs, + **kwargs: Any, ): # skip type checking due to https://github.com/python/mypy/issues/5485 extent = global_ctx.library_node.get_extents(node) # type: ignore @@ -581,7 +581,7 @@ def to_dataflow( nodes = flatten_list(nodes) if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return nodes - elif not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): raise ValueError("Can't mix dataflow and state nodes on same level.") read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes) @@ -615,10 +615,10 @@ def to_state(self, nodes, *, grid_subset: dcir.GridSubset): nodes = flatten_list(nodes) if all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): return nodes - elif all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): + if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)] - else: - raise ValueError("Can't mix dataflow and state nodes on same level.") + + raise ValueError("Can't mix dataflow and state nodes on same level.") def _process_map_item( self, @@ -628,8 +628,8 @@ def _process_map_item( global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainMap]: grid_subset = iteration_ctx.grid_subset read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_dataflow( @@ -723,11 +723,11 @@ def _process_loop_item( scope_nodes, item: Loop, *, - global_ctx, + global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainLoop]: grid_subset = union_node_grid_subsets(list(scope_nodes)) read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_state(scope_nodes, grid_subset=grid_subset) @@ -793,14 +793,14 @@ def _process_loop_item( def _process_iteration_item(self, scope, item, **kwargs): if isinstance(item, Map): return self._process_map_item(scope, item, **kwargs) - elif isinstance(item, Loop): + if isinstance(item, Loop): return self._process_loop_item(scope, item, **kwargs) - else: - raise ValueError("Invalid expansion specification set.") + + raise ValueError("Invalid expansion specification set.") def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs - ): + self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs: Any + ) -> dcir.NestedSDFG: overall_extent = Extent.zeros(2) for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 7b0f0ab7c4..6728ccaa7d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -89,7 +89,7 @@ def visit_Memlet( scope_node: dcir.ComputationNode, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, - connector_prefix="", + connector_prefix: str = "", symtable: ChainMap[eve.SymbolRef, dcir.Decl], ) -> None: field_decl = symtable[node.field] @@ -139,13 +139,12 @@ def visit_Tasklet( sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, + **kwargs: Any, ) -> None: code = TaskletCodegen.apply_codegen( node, read_memlets=node.read_memlets, write_memlets=node.write_memlets, - sdfg_ctx=sdfg_ctx, symtable=symtable, ) @@ -177,7 +176,7 @@ def visit_Tasklet( tasklet, tasklet, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx ) - def visit_Range(self, node: dcir.Range, **kwargs) -> Dict[str, str]: + def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]: start, end = node.interval.to_dace_symbolic() return {node.var: str(dace.subsets.Range([(start, end - 1, node.stride)]))} @@ -187,7 +186,7 @@ def visit_DomainMap( *, node_ctx: StencilComputationSDFGBuilder.NodeContext, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: ndranges = { k: v @@ -248,7 +247,7 @@ def visit_DomainLoop( node: dcir.DomainLoop, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: sdfg_ctx = sdfg_ctx.add_loop(node.index_range) self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs) @@ -259,7 +258,7 @@ def visit_ComputationState( node: dcir.ComputationState, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: sdfg_ctx.add_state() read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} @@ -289,7 +288,7 @@ def visit_FieldDecl( *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, non_transients: Set[eve.SymbolRef], - **kwargs, + **kwargs: Any, ) -> None: assert len(node.strides) == len(node.shape) sdfg_ctx.sdfg.add_array( @@ -307,7 +306,7 @@ def visit_SymbolDecl( node: dcir.SymbolDecl, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: if node.name not in sdfg_ctx.sdfg.symbols: sdfg_ctx.sdfg.add_symbol(node.name, stype=data_type_to_dace_typeclass(node.dtype)) @@ -319,7 +318,7 @@ def visit_NestedSDFG( sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None, node_ctx: Optional[StencilComputationSDFGBuilder.NodeContext] = None, symtable: ChainMap[eve.SymbolRef, Any], - **kwargs, + **kwargs: Any, ) -> dace.nodes.NestedSDFG: sdfg = dace.SDFG(node.label) inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext( @@ -365,13 +364,12 @@ def visit_NestedSDFG( StencilComputationSDFGBuilder._add_empty_edges( nsdfg, nsdfg, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx ) - else: - nsdfg = dace.nodes.NestedSDFG( - label=sdfg.label, - sdfg=sdfg, - inputs={memlet.connector for memlet in node.read_memlets}, - outputs={memlet.connector for memlet in node.write_memlets}, - symbol_mapping=symbol_mapping, - ) + return nsdfg - return nsdfg + return dace.nodes.NestedSDFG( + label=sdfg.label, + sdfg=sdfg, + inputs={memlet.connector for memlet in node.read_memlets}, + outputs={memlet.connector for memlet in node.write_memlets}, + symbol_mapping=symbol_mapping, + ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 696dc27387..8033c64710 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -31,7 +31,7 @@ def _visit_offset( *, access_info: dcir.FieldAccessInfo, decl: dcir.FieldDecl, - **kwargs, + **kwargs: Any, ) -> str: int_sizes: List[Optional[int]] = [] for i, axis in enumerate(access_info.axes()): @@ -60,27 +60,27 @@ def _visit_offset( res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) return str(res) - def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs): + def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs: Any) -> str: return self._visit_offset(node, **kwargs) - def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs): + def visit_VariableKOffset(self, node: dcir.VariableKOffset, **kwargs: Any) -> str: return self._visit_offset(node, **kwargs) def visit_IndexAccess( self, node: dcir.IndexAccess, *, - is_target, - sdfg_ctx, + is_target: bool, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, - ): + **kwargs: Any, + ) -> str: if is_target: memlets = kwargs["write_memlets"] else: # if this node is not a target, it will still use the symbol of the write memlet if the # field was previously written in the same memlet. memlets = kwargs["read_memlets"] + kwargs["write_memlets"] + try: memlet = next(mem for mem in memlets if mem.connector == node.name) except StopIteration: @@ -101,12 +101,12 @@ def visit_IndexAccess( ) ) index_strs.extend( - self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs) - for idx in node.data_index + self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index ) return f"{node.name}[{','.join(index_strs)}]" - def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): + def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str: + # Visiting order matters because targets must not contain the target symbols from the left visit right = self.visit(node.right, is_target=False, **kwargs) left = self.visit(node.left, is_target=True, **kwargs) return f"{left} = {right}" @@ -120,18 +120,18 @@ def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str: if builtin == common.BuiltInLiteral.TRUE: return "True" - elif builtin == common.BuiltInLiteral.FALSE: + if builtin == common.BuiltInLiteral.FALSE: return "False" raise NotImplementedError("Not implemented BuiltInLiteral encountered.") - def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs): + def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs: Any) -> str: value = self.visit(literal.value, in_idx=in_idx, **kwargs) if in_idx: return str(value) - else: - return "{dtype}({value})".format( - dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value - ) + + return "{dtype}({value})".format( + dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value + ) Cast = as_fmt("{dtype}({expr})") @@ -178,26 +178,26 @@ def visit_NativeFuncCall(self, call: common.NativeFuncCall, **kwargs: Any) -> st def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str: if dtype == common.DataType.BOOL: return "dace.bool_" - elif dtype == common.DataType.INT8: + if dtype == common.DataType.INT8: return "dace.int8" - elif dtype == common.DataType.INT16: + if dtype == common.DataType.INT16: return "dace.int16" - elif dtype == common.DataType.INT32: + if dtype == common.DataType.INT32: return "dace.int32" - elif dtype == common.DataType.INT64: + if dtype == common.DataType.INT64: return "dace.int64" - elif dtype == common.DataType.FLOAT32: + if dtype == common.DataType.FLOAT32: return "dace.float32" - elif dtype == common.DataType.FLOAT64: + if dtype == common.DataType.FLOAT64: return "dace.float64" raise NotImplementedError("Not implemented DataType encountered.") def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: if op == common.UnaryOperator.NOT: return " not " - elif op == common.UnaryOperator.NEG: + if op == common.UnaryOperator.NEG: return "-" - elif op == common.UnaryOperator.POS: + if op == common.UnaryOperator.POS: return "+" raise NotImplementedError("Not implemented UnaryOperator encountered.") @@ -207,16 +207,16 @@ def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: LocalScalarDecl = as_fmt("{name}: {dtype}") - def visit_Tasklet(self, node: dcir.Tasklet, **kwargs): + def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str: return "\n".join(self.visit(node.decls, **kwargs) + self.visit(node.stmts, **kwargs)) def _visit_conditional( self, cond: Optional[Union[dcir.Expr, common.HorizontalMask]], body: List[dcir.Stmt], - keyword, - **kwargs, - ): + keyword: str, + **kwargs: Any, + ) -> str: mask_str = "" indent = "" if cond is not None and (cond_str := self.visit(cond, is_target=False, **kwargs)): @@ -226,16 +226,16 @@ def _visit_conditional( body_code = [indent + b for b in body_code] return "\n".join([mask_str, *body_code]) - def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs): + def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs): + def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - def visit_While(self, node: dcir.While, **kwargs): + def visit_While(self, node: dcir.While, **kwargs: Any) -> Any: return self._visit_conditional(cond=node.cond, body=node.body, keyword="while", **kwargs) - def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs): + def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs: Any) -> str: clauses: List[str] = [] for axis, interval in zip(dcir.Axis.dims_horizontal(), node.intervals): diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index b5c23d2735..517e80ceb3 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -333,10 +333,9 @@ def compute_dcir_access_infos( global_grid_subset=access_info.global_grid_subset, ) ) - else: - res = ctx.access_infos + return res - return res + return ctx.access_infos def make_dace_subset( diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 27f19d2670..c8e8658413 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -32,7 +32,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 28d1e2acf6..59b4ef0881 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -45,7 +45,7 @@ class NodeVisitor: 3. ``self.generic_visit()``. This dispatching mechanism is implemented in the main :meth:`visit` - method and can be overriden in subclasses. Therefore, a simple way to extend + method and can be overridden in subclasses. Therefore, a simple way to extend the behavior of a visitor is by inheriting from lightweight `trait` classes with a custom ``visit()`` method, which wraps the call to the superclass' ``visit()`` and adds extra pre and post visit logic. Check :mod:`eve.traits` @@ -82,7 +82,7 @@ def apply(cls, tree, init_var, foo, bar=5, **kwargs): Notes: If you want to apply changes to nodes during the traversal, - use the :class:`NodeMutator` subclass, which handles correctly + use the :class:`NodeTranslator` subclass, which handles correctly structural modifications of the visited tree. """ From ea8d9dbfefa16ac71dee5afa48367d7018861721 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 15:57:00 +0100 Subject: [PATCH 044/178] ci: Bump gitlab ci on todi to ubuntu 22.04, cuda 12.6.2, cupy 13.3.0 (#1727) We were using ubuntu 22.04 which shipped with gcc 9.x.x. In order to get something more recent with proper utf-8 support I bumped to 22.04 on todi. On daint strange hangs occured so I kept everything as is there. --- ci/base.Dockerfile | 10 ++++++---- ci/cscs-ci.yml | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index d20d9ca6ef..ea7c4722c7 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,5 +1,6 @@ -ARG CUDA_VERSION=12.5.0 -FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +ARG CUDA_VERSION=12.6.2 +ARG UBUNTU_VERSION=22.04 +FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 @@ -22,7 +23,7 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ tk-dev \ libffi-dev \ liblzma-dev \ - python-openssl \ + $( [ "${UBUNTU_VERSION}" = "20.04" ] && echo "python-openssl" || echo "python3-openssl" ) \ libreadline-dev \ git \ rustc \ @@ -55,4 +56,5 @@ RUN pyenv update && \ ENV PATH="/root/.pyenv/shims:${PATH}" ARG CUPY_PACKAGE=cupy-cuda12x -RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==12.3.0 +ARG CUPY_VERSION=13.3.0 +RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==${CUPY_VERSION} diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7fcd65106d..e2833e3cd9 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -42,17 +42,21 @@ stages: DOCKERFILE: ci/base.Dockerfile # change to 'always' if you want to rebuild, even if target tag exists already (if-not-exists is the default, i.e. we could also skip the variable) CSCS_REBUILD_POLICY: if-not-exists - DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' + DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: CUDA_VERSION: 11.2.2 CUPY_PACKAGE: cupy-cuda11x + CUPY_VERSION: 12.3.0 # latest version that supports cuda 11 + UBUNTU_VERSION: 20.04 # 22.04 hangs on daint in some tests for unknown reasons. .build_baseimage_aarch64: extends: [.container-builder-cscs-gh200, .build_baseimage] variables: - CUDA_VERSION: 12.4.1 + CUDA_VERSION: 12.6.2 CUPY_PACKAGE: cupy-cuda12x + CUPY_VERSION: 13.3.0 + UBUNTU_VERSION: 22.04 # TODO: enable CI job when Todi is back in operational state when: manual From a00154ada421a05a700343fc648693c0ce78efc8 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 18 Nov 2024 08:51:48 +0100 Subject: [PATCH 045/178] feat[next][dace]: Use offset_type to represent neighborhood information for local dimensions (#1734) This PR adopts the `offset_type` design concept implemented in #1703 for Embedded-GTIR and applies it to the DaCe-GTIR backend. The only functional change is that the if-builtin is now expected to return the exact same data type, including the same `offset_type` if a local dimension is present in the result field. This change required updates to `test_gtir_reduce_with_cond_neighbors`. --- .../gtir_builtin_translators.py | 155 +++++---- .../runners/dace_fieldview/gtir_dataflow.py | 315 ++++++++++-------- .../runners/dace_fieldview/gtir_sdfg.py | 18 +- .../runners/dace_fieldview/utility.py | 9 +- .../dace_tests/test_gtir_to_sdfg.py | 213 +++++------- 5 files changed, 352 insertions(+), 358 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index bb37440fe2..69aedf44d6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -44,15 +44,50 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. - gt_dtype: GT4Py type definition, which includes the field domain information. - local_offset: Provides information about the local dimension in`FieldType` data. - Set to 'None' for scalar data. Can be 'None' for `FieldType` data with - only global (horizontal or vertical) dimensions. + gt_type: GT4Py type definition, which includes the field domain information. """ dc_node: dace.nodes.AccessNode - gt_dtype: ts.FieldType | ts.ScalarType - local_offset: Optional[str] + gt_type: ts.FieldType | ts.ScalarType + + def get_local_view( + self, domain: FieldopDomain + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + """Helper method to access a field in local view, given a field operator domain.""" + if isinstance(self.gt_type, ts.ScalarType): + return gtir_dataflow.MemletExpr( + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + ) + + if isinstance(self.gt_type, ts.FieldType): + indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) + for dim, _, _ in domain + } + local_dims = [ + dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL + ] + + if len(local_dims) == 0: + return gtir_dataflow.IteratorExpr( + self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + ) + + elif len(local_dims) == 1: + field_dtype = itir_ts.ListType( + element_type=self.gt_type.dtype, offset_type=local_dims[0] + ) + field_dims = [ + dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + ] + return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + + else: + raise ValueError( + f"Unexpected data field {self.dc_node.data} with more than one local dimension." + ) + + raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") FieldopDomain: TypeAlias = list[ @@ -111,31 +146,13 @@ def _parse_fieldop_arg( ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: """Helper method to visit an expression passed as argument to a field operator.""" - arg = sdfg_builder.visit( - node, - sdfg=sdfg, - head_state=state, - ) + arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) # arguments passed to field operator should be plain fields, not tuples of fields if not isinstance(arg, FieldopData): raise ValueError(f"Received {node} as argument to field operator, expected a field.") - if isinstance(arg.gt_dtype, ts.ScalarType): - return gtir_dataflow.MemletExpr(arg.dc_node, sbs.Indices([0])) - elif isinstance(arg.gt_dtype, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain - } - dims = arg.gt_dtype.dims + ( - # we add an extra anonymous dimension in the iterator definition to enable - # dereferencing elements in `ListType` - [gtx_common.Dimension("")] if isinstance(arg.gt_dtype.dtype, itir_ts.ListType) else [] - ) - return gtir_dataflow.IteratorExpr(arg.dc_node, dims, indices, arg.local_offset) - else: - raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.") + return arg.get_local_view(domain) def _get_field_shape( @@ -178,20 +195,27 @@ def _create_temporary_field( if isinstance(output_desc, dace.data.Array): assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) - field_dtype = node_type.dtype.element_type + assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): - field_dtype = node_type.dtype + assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) else: raise ValueError(f"Cannot create field for dace type {output_desc}.") # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, dace_utils.as_dace_type(field_dtype)) + temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, node_type.dtype) - return FieldopData(field_node, field_type, local_offset=dataflow_output.result.local_offset) + if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): + field_dtype = dataflow_output.result.gt_dtype + else: + assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = dataflow_output.result.gt_dtype.element_type + assert dataflow_output.result.gt_dtype.offset_type is not None + field_dims.append(dataflow_output.result.gt_dtype.offset_type) + + return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -273,7 +297,6 @@ def translate_as_fieldop( if isinstance(node.type.dtype, itir_ts.ListType): assert isinstance(output_desc, dace.data.Array) - assert set(output_desc.offset) == {0} # additional local dimension for neighbors # TODO(phimuell): Investigate if we should swap the two. output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) @@ -383,7 +406,7 @@ def translate_broadcast_scalar( external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype), local_offset=None) + return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) def translate_if( @@ -439,14 +462,14 @@ def translate_if( head_state=false_state, ) - def make_temps(output_data: FieldopData) -> FieldopData: - desc = output_data.dc_node.desc(sdfg) - data_name, _ = sdfg.add_temp_transient_like(desc) - data_node = state.add_access(data_name) + def construct_output(inner_data: FieldopData) -> FieldopData: + inner_desc = inner_data.dc_node.desc(sdfg) + outer, _ = sdfg.add_temp_transient_like(inner_desc) + outer_node = state.add_access(outer) - return FieldopData(data_node, output_data.gt_dtype, output_data.local_offset) + return FieldopData(outer_node, inner_data.gt_type) - result_temps = gtx_utils.tree_map(make_temps)(true_br_args) + result_temps = gtx_utils.tree_map(construct_output)(true_br_args) fields: Iterable[tuple[FieldopData, FieldopData, FieldopData]] = zip( gtx_utils.flatten_nested_tuple((true_br_args,)), @@ -456,7 +479,10 @@ def make_temps(output_data: FieldopData) -> FieldopData: ) for true_br, false_br, temp in fields: - assert true_br.gt_dtype == false_br.gt_dtype + if true_br.gt_type != false_br.gt_type: + raise ValueError( + f"Different type of result fields on if-branches '{true_br.gt_type}' vs '{false_br.gt_type}'." + ) true_br_node = true_br.dc_node false_br_node = false_br.dc_node @@ -482,40 +508,31 @@ def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, - sym_name: str, - sym_type: ts.DataType, + data_name: str, + data_type: ts.DataType, ) -> FieldopResult: - if isinstance(sym_type, ts.FieldType): - sym_node = state.add_access(sym_name) - local_dims = [dim for dim in sym_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] - if len(local_dims) > 1: - raise ValueError(f"Field {sym_name} has more than one local dimension.") - elif len(local_dims) == 1: - # we ensure that the name of the local dimension corresponds to a valid - # connectivity-based offset provider - local_offset = next(iter(local_dims)).value - assert isinstance( - sdfg_builder.get_offset_provider(local_offset), gtx_common.Connectivity - ) - else: - local_offset = None - return FieldopData(sym_node, sym_type, local_offset) - elif isinstance(sym_type, ts.ScalarType): - if sym_name in sdfg.symbols: - sym_node = _get_symbolic_value( - sdfg, state, sdfg_builder, sym_name, sym_type, temp_name=f"__{sym_name}" + if isinstance(data_type, ts.FieldType): + data_node = state.add_access(data_name) + return FieldopData(data_node, data_type) + + elif isinstance(data_type, ts.ScalarType): + if data_name in sdfg.symbols: + data_node = _get_symbolic_value( + sdfg, state, sdfg_builder, data_name, data_type, temp_name=f"__{data_name}" ) else: - sym_node = state.add_access(sym_name) - return FieldopData(sym_node, sym_type, local_offset=None) - elif isinstance(sym_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(sym_name, sym_type) + data_node = state.add_access(data_name) + return FieldopData(data_node, data_type) + + elif isinstance(data_type, ts.TupleType): + tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) return tuple( _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) for fname, ftype in tuple_fields ) + else: - raise NotImplementedError(f"Symbol type {type(sym_type)} not supported.") + raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") def _get_symbolic_value( @@ -562,7 +579,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type, local_offset=None) + return FieldopData(data_node, data_type) def translate_make_tuple( @@ -646,7 +663,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_dtype, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -691,7 +708,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type, local_offset=None) + return FieldopData(temp_node, node.type) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cf91d15aba..73b6e2ed4c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -29,6 +29,12 @@ from gt4py.next.type_system import type_info as ti, type_specifications as ts +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. +_CONST_DIM = gtx_common.Dimension(value="_CONST_DIM", kind=gtx_common.DimensionKind.LOCAL) + + @dataclasses.dataclass(frozen=True) class ValueExpr: """ @@ -41,15 +47,12 @@ class ValueExpr: the result of a field operator, basically the data storage outside a global map. Args: - dc_node: Access node to the data storage, can be either a scalar or a local list. - gt_dtype: GT4Py type definition, which includes the field domain information. - local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. """ dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) @@ -58,15 +61,14 @@ class MemletExpr: Scalar or array data access through a memlet. Args: - dc_node: Access node to the data storage, can be either a scalar or a local list. + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. subset: Represents the subset to use in memlet to access the above data. - local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. """ dc_node: dace.nodes.AccessNode + gt_dtype: itir_ts.ListType | ts.ScalarType subset: sbs.Indices | sbs.Range - local_offset: Optional[str] = None @dataclasses.dataclass(frozen=True) @@ -87,19 +89,17 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. dimensions: Field domain represented as a sorted list of dimensions, needed to order the map index variables and dereference an element in the field. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. - local_offset: Provides information about the local dimension in`FieldType` data. - For a more detailed explanation see `gtir_builtin_translators.FieldopData`. - """ field: dace.nodes.AccessNode + gt_dtype: itir_ts.ListType | ts.ScalarType dimensions: list[gtx_common.Dimension] indices: dict[gtx_common.Dimension, DataExpr] - local_offset: Optional[str] = None class DataflowInputEdge(Protocol): @@ -383,18 +383,18 @@ def _construct_tasklet_result( dc_dtype: dace.typeclass, src_node: dace.nodes.Tasklet, src_connector: str, - local_offset: Optional[str] = None, use_array: bool = False, ) -> ValueExpr: - temp_name = self.sdfg.temp_data_name() + data_type = dace_utils.as_itir_type(dc_dtype) if use_array: # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) # in order to allow for composition of array shape in external memlets. - self.sdfg.add_array(temp_name, (1,), dc_dtype, transient=True) + temp_name, _ = self.sdfg.add_temp_transient((1,), dc_dtype) else: + temp_name = self.sdfg.temp_data_name() self.sdfg.add_scalar(temp_name, dc_dtype, transient=True) - data_type = dace_utils.as_itir_type(dc_dtype) + temp_node = self.state.add_access(temp_name) self._add_edge( src_node, @@ -403,7 +403,14 @@ def _construct_tasklet_result( None, dace.Memlet(data=temp_name, subset="0"), ) - return ValueExpr(temp_node, data_type, local_offset) + return ValueExpr( + dc_node=temp_node, + gt_dtype=( + itir_ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + if use_array + else data_type + ), + ) def _visit_deref(self, node: gtir.FunCall) -> DataExpr: """ @@ -435,81 +442,87 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # deref a zero-dimensional field assert len(arg_expr.dimensions) == 0 assert isinstance(node.type, ts.ScalarType) - return MemletExpr(arg_expr.field, subset="0") + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") + # default case: deref a field with one or more dimensions - assert len(field_desc.shape) == len(arg_expr.dimensions) if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet + if isinstance(arg_expr.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 + assert arg_expr.gt_dtype.offset_type is not None + field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] + else: + assert len(field_desc.shape) == len(arg_expr.dimensions) + field_dims = arg_expr.dimensions + field_subset = sbs.Range( (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] if dim in arg_expr.indices else (0, size - 1, 1) - for dim, size in zip(arg_expr.dimensions, field_desc.shape) + for dim, size in zip(field_dims, field_desc.shape) ) - return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset) + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) - else: - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] - index_connectors = [ - IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - if not isinstance(index, SymbolExpr) - ] - # here `internals` refer to the names used as index in the tasklet code string: - # an index can be either a connector name (for dynamic/indirect indices) - # or a symbol value (for literal values and scalar arguments). - index_internals = ",".join( - str(index.value) - if isinstance(index, SymbolExpr) - else IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - ) - deref_node = self._add_tasklet( - "runtime_deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", - ) - # add new termination point for the field parameter - self._add_input_data_edge( - arg_expr.field, - sbs.Range.from_array(field_desc), - deref_node, - "field", - ) + # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, + # either indirection through connectivity table or dynamic cartesian offset. + assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) + assert len(field_desc.shape) == len(arg_expr.dimensions) + field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + index_connectors = [ + IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + if not isinstance(index, SymbolExpr) + ] + # here `internals` refer to the names used as index in the tasklet code string: + # an index can be either a connector name (for dynamic/indirect indices) + # or a symbol value (for literal values and scalar arguments). + index_internals = ",".join( + str(index.value) + if isinstance(index, SymbolExpr) + else IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + ) + deref_node = self._add_tasklet( + "runtime_deref", + {"field"} | set(index_connectors), + {"val"}, + code=f"val = field[{index_internals}]", + ) + # add new termination point for the field parameter + self._add_input_data_edge( + arg_expr.field, + sbs.Range.from_array(field_desc), + deref_node, + "field", + ) - for dim, index_expr in field_indices: - # add termination points for the dynamic iterator indices - deref_connector = IndexConnectorFmt.format(dim=dim.value) - if isinstance(index_expr, MemletExpr): - self._add_input_data_edge( - index_expr.dc_node, - index_expr.subset, - deref_node, - deref_connector, - ) + for dim, index_expr in field_indices: + # add termination points for the dynamic iterator indices + deref_connector = IndexConnectorFmt.format(dim=dim.value) + if isinstance(index_expr, MemletExpr): + self._add_input_data_edge( + index_expr.dc_node, + index_expr.subset, + deref_node, + deref_connector, + ) - elif isinstance(index_expr, ValueExpr): - self._add_edge( - index_expr.dc_node, - None, - deref_node, - deref_connector, - dace.Memlet(data=index_expr.dc_node.data, subset="0"), - ) - else: - assert isinstance(index_expr, SymbolExpr) + elif isinstance(index_expr, ValueExpr): + self._add_edge( + index_expr.dc_node, + None, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.dc_node.data, subset="0"), + ) + else: + assert isinstance(index_expr, SymbolExpr) - return self._construct_tasklet_result( - field_desc.dtype, deref_node, "val", arg_expr.local_offset - ) + return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: - assert len(node.args) == 2 assert isinstance(node.type, itir_ts.ListType) + assert len(node.args) == 2 assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value @@ -543,8 +556,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: # to the view nodes. The simplify pass will remove the redundant access nodes. field_slice = self._construct_local_view( MemletExpr( - it.field, - sbs.Range.from_string( + dc_node=it.field, + gt_dtype=node.type, + subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] if dim != offset_provider.neighbor_axis @@ -556,8 +570,11 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) connectivity_slice = self._construct_local_view( MemletExpr( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), + dc_node=self.state.add_access(connectivity), + gt_dtype=node.type, + subset=sbs.Range.from_string( + f"{origin_index.value}, 0:{offset_provider.max_neighbors}" + ), ) ) @@ -565,8 +582,8 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: (offset_provider.max_neighbors,), field_desc.dtype ) neighbors_node = self.state.add_access(neighbors_temp) - - neighbor_idx = dace_gtir_utils.get_map_variable(offset) + offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL) + neighbor_idx = dace_gtir_utils.get_map_variable(offset_type) index_connector = "__index" output_connector = "__val" @@ -604,7 +621,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: external_edges=True, ) - return ValueExpr(neighbors_node, node.type, offset) + return ValueExpr( + dc_node=neighbors_node, gt_dtype=itir_ts.ListType(node.type.element_type, offset_type) + ) def _visit_map(self, node: gtir.FunCall) -> ValueExpr: """ @@ -629,8 +648,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type.element_type, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(node.type.element_type) - input_args = [self.visit(arg) for arg in node.args] - input_connectors = [f"__arg{i}" for i in range(len(input_args))] + input_connectors = [f"__arg{i}" for i in range(len(node.args))] output_connector = "__out" # Here we build the body of the tasklet @@ -638,27 +656,37 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: fun_python_code = gtir_python_codegen.get_source(fun_node) tasklet_expression = f"{output_connector} = {fun_python_code}" - input_local_offsets = [ - input_arg.local_offset for input_arg in input_args if input_arg.local_offset is not None - ] - if len(input_local_offsets) == 0: + input_args = [self.visit(arg) for arg in node.args] + input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + for input_arg in input_args: + assert isinstance(input_arg.gt_dtype, itir_ts.ListType) + assert input_arg.gt_dtype.offset_type is not None + offset_type = input_arg.gt_dtype.offset_type + if offset_type == _CONST_DIM: + # this input argument is the result of `make_const_list` + continue + offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) + assert isinstance(offset_provider, gtx_common.Connectivity) + input_connectivities[offset_type] = offset_provider + + if len(input_connectivities) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors # have the same length, that is the same value of 'max_neighbors'. - local_connectivities = dace_utils.filter_connectivities( - { - offset: self.subgraph_builder.get_offset_provider(offset) - for offset in input_local_offsets - } - ) - if len(set(table.max_neighbors for table in local_connectivities.values())) != 1: - raise ValueError( - "Unexpected arguments to map expression with different local dimensions." + if ( + len( + set( + (conn.has_skip_values, conn.max_neighbors) + for conn in input_connectivities.values() + ) ) - local_offset, offset_provider = next(iter(local_connectivities.items())) + != 1 + ): + raise ValueError("Unexpected arguments to map expression with different neighborhood.") + offset_type, offset_provider = next(iter(input_connectivities.items())) local_size = offset_provider.max_neighbors - map_index = dace_gtir_utils.get_map_variable(local_offset) + map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to @@ -668,47 +696,31 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: # than representing map-to-map edges (which require memlets with 2 pass-nodes). input_memlets = {} input_nodes = {} - skip_value_connectivities: dict[str, gtx_common.Connectivity] = {} - for conn, input_expr in zip(input_connectors, input_args): - input_node = self._construct_local_view(input_expr).dc_node + for conn, input_arg in zip(input_connectors, input_args): + input_node = self._construct_local_view(input_arg).dc_node input_desc = input_node.desc(self.sdfg) # we assume that there is a single local dimension if len(input_desc.shape) != 1: raise ValueError(f"More than one local dimension in map expression {node}.") input_size = input_desc.shape[0] if input_size == 1: + assert input_arg.gt_dtype.offset_type == _CONST_DIM input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0") - elif input_size != local_size: + elif input_size == local_size: + input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) + else: raise ValueError( f"Argument to map node with local size {input_size}, expected {local_size}." ) - else: - assert input_expr.local_offset - input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index) - input_nodes[input_node.data] = input_node result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - skip_value_connectivities = { - offset: offset_provider - for offset, offset_provider in local_connectivities.items() - if offset_provider.has_skip_values - } - - if len(skip_value_connectivities) == 0: - result_offset = local_offset - else: - # In case one or more of input expressions contain skip values, we use + if offset_provider.has_skip_values: + # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. - # Therefore, the result of map computation will also contain skip values. - # GT4Py guarantees that the skip values are placed in the same positions - # for all input expressions. - - result_offset, offset_provider = next(iter(skip_value_connectivities.items())) - - connectivity = dace_utils.connectivity_identifier(result_offset) + connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False @@ -716,8 +728,13 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: connectivity_slice = self._construct_local_view( MemletExpr( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + dc_node=self.state.add_access(connectivity), + gt_dtype=itir_ts.ListType( + element_type=node.type.element_type, offset_type=offset_type + ), + subset=sbs.Range.from_string( + f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + ), ) ) @@ -749,7 +766,10 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: external_edges=True, ) - return ValueExpr(result_node, dc_dtype, result_offset) + return ValueExpr( + dc_node=result_node, + gt_dtype=itir_ts.ListType(node.type.element_type, offset_type), + ) def _make_reduce_with_skip_values( self, @@ -774,8 +794,12 @@ def _make_reduce_with_skip_values( """ origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) - assert input_expr.local_offset is not None - connectivity = dace_utils.connectivity_identifier(input_expr.local_offset) + assert ( + isinstance(input_expr.gt_dtype, itir_ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_node = self.state.add_access(connectivity) connectivity_desc = connectivity_node.desc(self.sdfg) connectivity_desc.transient = False @@ -881,8 +905,12 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: input_expr = self.visit(node.args[0]) assert isinstance(input_expr, (MemletExpr, ValueExpr)) - assert input_expr.local_offset is not None - offset_provider = self.subgraph_builder.get_offset_provider(input_expr.local_offset) + assert ( + isinstance(input_expr.gt_dtype, itir_ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) assert isinstance(offset_provider, gtx_common.Connectivity) if offset_provider.has_skip_values: @@ -998,9 +1026,13 @@ def _make_cartesian_shift( # a new iterator with a shifted index along one dimension return IteratorExpr( - it.field, - it.dimensions, - {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, + field=it.field, + gt_dtype=it.gt_dtype, + dimensions=it.dimensions, + indices={ + dim: (new_index if dim == offset_dim else index) + for dim, index in it.indices.items() + }, ) def _make_dynamic_neighbor_offset( @@ -1068,8 +1100,9 @@ def _make_unstructured_shift( if isinstance(offset_expr, SymbolExpr): # use memlet to retrieve the neighbor index shifted_indices[neighbor_dim] = MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_expr.value]), + dc_node=offset_table_node, + gt_dtype=it.gt_dtype, + subset=sbs.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1077,7 +1110,9 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr(it.field, it.dimensions, shifted_indices) + return IteratorExpr( + field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices + ) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index da940e883c..ad8f490f12 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -344,9 +344,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return gtir_builtin_translators.FieldopData( - temp_node, field.gt_dtype, field.local_offset - ) + return gtir_builtin_translators.FieldopData(temp_node, field.gt_type) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -489,9 +487,9 @@ def visit_SetAt( target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient - if isinstance(target.gt_dtype, ts.FieldType): + if isinstance(target.gt_type, ts.FieldType): subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_dtype.dims + f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) else: assert len(domain) == 0 @@ -582,7 +580,7 @@ def visit_Lambda( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_dtype + pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } @@ -742,9 +740,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData( - outer_node, inner_data.gt_dtype, inner_data.local_offset - ) + outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -753,9 +749,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData( - outer_node, inner_data.gt_dtype, inner_data.local_offset - ) + outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index baae8a6ccd..caec6cd87e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -19,15 +19,10 @@ from gt4py.next.type_system import type_specifications as ts -def get_map_variable(dim: gtx_common.Dimension | str) -> str: +def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. """ - if not isinstance(dim, gtx_common.Dimension): - if len(dim) != 0: - dim = gtx_common.Dimension(dim, gtx_common.DimensionKind.LOCAL) - else: - raise ValueError("Dimension name cannot be empty.") suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" return f"i_{dim.value}_gtx_{dim.kind}{suffix}" @@ -68,7 +63,7 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. """ return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data] + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index cc72adae4f..a94157ecd2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1352,7 +1352,9 @@ def test_gtir_reduce_with_skip_values(): e = np.random.rand(SKIP_VALUE_MESH.num_edges) v_ref = [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) for v2e_neighbors in connectivity_V2E.table ] @@ -1394,120 +1396,74 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - # create mesh with skip values - connectivity_V2E_skip = copy.deepcopy(connectivity_V2E) - connectivity_V2E_skip.has_skip_values = True - connectivity_V2E_skip.table = np.asarray( - [ - [x if i != skip_idx else gtx_common._DEFAULT_SKIP_VALUE for i, x in enumerate(row)] - for skip_idx, row in zip( - np.random.randint(0, connectivity_V2E.max_neighbors, size=SIMPLE_MESH.num_vertices), - connectivity_V2E.table, - strict=True, - ) - ], - dtype=connectivity_V2E.table.dtype, - ) - # safety check that the connectivity table actually contains skip values - assert len(np.where(connectivity_V2E.table == gtx_common._DEFAULT_SKIP_VALUE)) != 0 - - offset_provider = SIMPLE_MESH_OFFSET_PROVIDER | { - "V2E_skip": connectivity_V2E_skip, - } - - V2E_SKIP_SYMBOLS = dict( - __connectivity_V2E_skip_size_0=SIMPLE_MESH.num_vertices, - __connectivity_V2E_skip_size_1=connectivity_V2E_skip.max_neighbors, - __connectivity_V2E_skip_stride_0=connectivity_V2E_skip.max_neighbors, - __connectivity_V2E_skip_stride_1=1, - ) - - e = np.random.rand(SIMPLE_MESH.num_edges) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ functools.reduce( lambda x, y: x + y, map( lambda x: 0.0 if x[1] == gtx_common._DEFAULT_SKIP_VALUE else x[0], - zip((e[v2e_neighbors] * e[v2e_skip_neighbors]) + 1.0, v2e_skip_neighbors), + zip((e[v2e_neighbors] * v2e_values) + 1.0, v2e_neighbors), ), init_value, ) - for v2e_neighbors, v2e_skip_neighbors in zip( - connectivity_V2E.table, connectivity_V2E_skip.table - ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) ] - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.map_("plus")( - im.map_("multiplies")( - im.neighbors("V2E", "it"), - im.neighbors("V2E_skip", "it"), + testee = gtir.Program( + id=f"reduce_dot_product", + function_definitions=[], + params=[ + gtir.Sym(id="v2e_field", type=V2E_FTYPE), + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), + gtir.Sym(id="nvertices", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")( + im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( + im.deref("it") + ) ), - im.call("make_const_list")(1.0), + vertex_domain, ) - ) - ), - vertex_domain, - ) - )("edges") - - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) - )( - im.op_as_fieldop(im.map_("plus"), vertex_domain)( - im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E_skip", "edges", vertex_domain), - ), - im.op_as_fieldop("make_const_list", vertex_domain)(1.0), - ) + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + "v2e_field", + ), + im.op_as_fieldop("make_const_list", vertex_domain)(1.0), + ) + ), + domain=vertex_domain, + target=gtir.SymRef(id="vertices"), + ) + ], ) - for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): - testee = gtir.Program( - id=f"reduce_dot_product_{i}", - function_definitions=[], - params=[ - gtir.Sym(id="edges", type=EFTYPE), - gtir.Sym(id="vertices", type=VFTYPE), - gtir.Sym(id="nvertices", type=SIZE_TYPE), - ], - declarations=[], - body=[ - gtir.SetAt( - expr=stencil, - domain=vertex_domain, - target=gtir.SymRef(id="vertices"), - ) - ], - ) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, offset_provider) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) - sdfg( - e, - v, - connectivity_V2E=connectivity_V2E.table, - connectivity_V2E_skip=connectivity_V2E_skip.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - **V2E_SKIP_SYMBOLS, - ) - assert np.allclose(v, v_ref) + sdfg( + v2e_field, + e, + v, + connectivity_V2E=connectivity_V2E.table, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, + ) + assert np.allclose(v, v_ref) def test_gtir_reduce_with_cond_neighbors(): @@ -1518,6 +1474,7 @@ def test_gtir_reduce_with_cond_neighbors(): function_definitions=[], params=[ gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="v2e_field", type=V2E_FTYPE), gtir.Sym(id="edges", type=EFTYPE), gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), @@ -1535,7 +1492,7 @@ def test_gtir_reduce_with_cond_neighbors(): )( im.if_( "pred", - im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), + "v2e_field", im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) ), @@ -1545,49 +1502,45 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E_simple = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] - assert isinstance(connectivity_V2E_simple, gtx_common.NeighborTable) - connectivity_V2E_skip_values = copy.deepcopy(SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"]) - assert isinstance(connectivity_V2E_skip_values, gtx_common.NeighborTable) - assert SKIP_VALUE_MESH.num_vertices <= SIMPLE_MESH.num_vertices - connectivity_V2E_skip_values.table = np.concatenate( - ( - connectivity_V2E_skip_values.table[:, 0 : connectivity_V2E_simple.max_neighbors], - connectivity_V2E_simple.table[SKIP_VALUE_MESH.num_vertices :, :], - ), - axis=0, - ) - connectivity_V2E_skip_values.max_neighbors = connectivity_V2E_simple.max_neighbors + connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - e = np.random.rand(SIMPLE_MESH.num_edges) + v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) - for use_full in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir( - testee, - SIMPLE_MESH_OFFSET_PROVIDER | {"V2E_FULL": connectivity_V2E_skip_values}, - ) + for use_sparse in [False, True]: + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [ + v if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 + for i, v in zip(v2e_neighbors, v2e_values, strict=True) + ], + init_value, ) - for v2e_neighbors in ( - connectivity_V2E_simple.table if use_full else connectivity_V2E_skip_values.table + if use_sparse + else functools.reduce( + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) ] sdfg( - np.bool_(use_full), + np.bool_(use_sparse), + v2e_field, e, v, - connectivity_V2E=connectivity_V2E_skip_values.table, - connectivity_V2E_FULL=connectivity_V2E_simple.table, + connectivity_V2E=connectivity_V2E.table, **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - __connectivity_V2E_FULL_size_0=SIMPLE_MESH.num_edges, - __connectivity_V2E_FULL_size_1=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_0=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_1=1, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.max_neighbors, + __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) From a9a99928c1b1ba5c05234e45e36ccf0ac7c79214 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 18 Nov 2024 15:21:01 +0100 Subject: [PATCH 046/178] feat[next]: Upgrade dace dependency to v1.0.0 (#1740) DaCe version upgraded to `1.0.0`. It is also constrained to `< 1.1.0 `because the plan for DaCe v1.x is to introduce some breaking changes. GPU tests still fail with GTIR DaCe backend (`test_double_use_scalar`) so they will be enabled in a separate PR. Additional changes: - Removed limitation on SymPy version since DaCe is now compatible with SymPy v1.13 - CUDA version upgraded from 11.2 to 11.4 to avoid this compile error in gpu build: `dace/codegen/../runtime/include/dace/math.h(499): error: A __device__ variable cannot be marked constexpr` --- .pre-commit-config.yaml | 2 +- ci/cscs-ci.yml | 2 +- constraints.txt | 13 ++++++------- min-extra-requirements-test.txt | 3 +-- pyproject.toml | 2 +- requirements-dev.txt | 13 ++++++------- .../dace_fieldview/transformations/loop_blocking.py | 4 ---- 7 files changed, 16 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07f75177ea..1c3b6e693f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.3 + rev: v0.7.4 ##[[[end]]] hooks: # Run the linter. diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index e2833e3cd9..7adb88459e 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -46,7 +46,7 @@ stages: .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: - CUDA_VERSION: 11.2.2 + CUDA_VERSION: 11.4.3 CUPY_PACKAGE: cupy-cuda11x CUPY_VERSION: 12.3.0 # latest version that supports cuda 11 UBUNTU_VERSION: 20.04 # 22.04 hangs on daint in some tests for unknown reasons. diff --git a/constraints.txt b/constraints.txt index 4247f4951d..b4b8bc00d4 100644 --- a/constraints.txt +++ b/constraints.txt @@ -33,7 +33,7 @@ contourpy==1.1.1 # via matplotlib coverage==7.6.1 # via -r requirements-dev.in, pytest-cov cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) -dace==0.16.1 # via gt4py (pyproject.toml) +dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in debugpy==1.8.8 # via ipykernel decorator==5.1.1 # via ipython @@ -50,7 +50,7 @@ factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat filelock==3.16.1 # via tox, virtualenv -fonttools==4.54.1 # via matplotlib +fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython @@ -67,7 +67,7 @@ iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake ipython==8.12.3 # via ipykernel jedi==0.19.2 # via ipython -jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx +jinja2==3.1.4 # via gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema jupyter-client==8.6.3 # via ipykernel, nbclient @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff -packaging==24.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.4 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.3 # via -r requirements-dev.in +ruff==0.7.4 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil @@ -157,7 +157,7 @@ sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach -sympy==1.12.1 # via dace, gt4py (pyproject.toml) +sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) tach==0.14.3 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox @@ -173,7 +173,6 @@ urllib3==2.2.3 # via requests virtualenv==20.27.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -websockets==13.1 # via dace wheel==0.45.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 4190570105..57c0d3969d 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -61,7 +61,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.1 -dace==0.16.1 +dace==1.0.0 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -101,7 +101,6 @@ scipy==1.9.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.9 tabulate==0.8.10 tach==0.10.7 tomli==2.0.1; python_version < "3.11" diff --git a/pyproject.toml b/pyproject.toml index 1504c8b17b..02d301957c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ all-cuda12 = ['gt4py[cuda12,dace,formatting,jax-cuda12,performance,testing]'] # Other extras cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.16.1', 'sympy>=1.9,<1.13'] # see https://github.com/spcl/dace/pull/1620 +dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 formatting = ['clang-format>=9.0'] gpu = ['cupy>=12.0'] jax-cpu = ['jax[cpu]>=0.4.18; python_version>="3.10"'] diff --git a/requirements-dev.txt b/requirements-dev.txt index ca7eb32487..9f95779fd5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ contourpy==1.1.1 # via -c constraints.txt, matplotlib coverage[toml]==7.6.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) +dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in debugpy==1.8.8 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython @@ -50,7 +50,7 @@ factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pyte faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, tox, virtualenv -fonttools==4.54.1 # via -c constraints.txt, matplotlib +fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython @@ -67,7 +67,7 @@ iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake ipython==8.12.3 # via -c constraints.txt, ipykernel jedi==0.19.2 # via -c constraints.txt, ipython -jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx +jinja2==3.1.4 # via -c constraints.txt, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema jupyter-client==8.6.3 # via -c constraints.txt, ipykernel, nbclient @@ -95,7 +95,7 @@ ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox +packaging==24.2 # via -c constraints.txt, black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython @@ -139,7 +139,7 @@ requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.4 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.3 # via -c constraints.txt, -r requirements-dev.in +ruff==0.7.4 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb @@ -156,7 +156,7 @@ sphinxcontrib-qthelp==1.0.3 # via -c constraints.txt, sphinx sphinxcontrib-serializinghtml==1.1.5 # via -c constraints.txt, sphinx stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach -sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) +sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox @@ -172,7 +172,6 @@ urllib3==2.2.3 # via -c constraints.txt, requests virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -websockets==13.1 # via -c constraints.txt, dace wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index 7acd997a0d..d7326e1131 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -63,16 +63,12 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): dtype=set, allow_none=True, default=None, - optional=True, - optional_condition=lambda _: False, desc="Set of nodes that are independent of the blocking parameter.", ) dependent_nodes = dace_properties.Property( dtype=set, allow_none=True, default=None, - optional=True, - optional_condition=lambda _: False, desc="Set of nodes that are dependent on the blocking parameter.", ) From 9dbc8842a2e6e16855da030934f7aecc23f8417b Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Nov 2024 07:34:23 +0100 Subject: [PATCH 047/178] feat[next]: Enable GPU tests on GTIR DaCe backend (#1741) DaCe v1.0.0 allows to enable GPU tests on the GTIR backend. An issue was found in `test_double_use_scalar`. The dace gpu transformations have a bug that produces invalid code for SDFGs containing scalar expressions outside the field operator. A workaround is to run the simplify pass in order to bring the SDFG to a canonical form. The changes in test code (`test_execution.py`) are pure cleanup. --- .../program_processors/runners/dace_common/utility.py | 2 +- .../runners/dace_fieldview/workflow.py | 7 ++++++- tests/next_tests/definitions.py | 6 +----- .../feature_tests/ffront_tests/test_execution.py | 11 +---------- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index d678fdab7f..bc01e2abda 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -19,7 +19,7 @@ # regex to match the symbols for field shape and strides -FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+") +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+") def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 85ae95c432..aa4fd0cd3e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -57,7 +57,12 @@ def generate_sdfg( if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) elif on_gpu: - gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=False) + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). + gtx_transformations.gt_simplify(sdfg) + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) return sdfg diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c86ba88ead..01fd18897d 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -193,11 +193,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST - + [ - # TODO(edopao): Enable when GPU codegen issues related to symbolic domain are fixed. - (ALL, XFAIL, UNSUPPORTED_MESSAGE), - ], + OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f10f195d3a..a5453151e6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,21 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce -from gt4py.next.otf import languages, stages, workflow -from gt4py.next.otf.binding import interface import numpy as np import pytest -import diskcache -from gt4py.eve import SymbolName - import gt4py.next as gtx from gt4py.next import ( astype, broadcast, common, - constructors, errors, - field_utils, float32, float64, int32, @@ -30,8 +23,6 @@ neighbor_sum, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import gtfn -from gt4py.next.type_system import type_specifications as ts from gt4py.next import utils as gt_utils from next_tests.integration_tests import cases @@ -306,7 +297,7 @@ def test_double_use_scalar(cartesian_case): # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't # work for this case. @gtx.field_operator - def testee(a: np.int32, b: np.int32, c: cases.IField) -> cases.IField: + def testee(a: int32, b: int32, c: cases.IField) -> cases.IField: tmp = a * b tmp2 = tmp * tmp # important part here is that we use the intermediate twice so that it is From 5e937363ce3427f001addecf81815894ec3b9941 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 20 Nov 2024 13:47:08 +0100 Subject: [PATCH 048/178] feat[next]: Extend the IR pass for pruning of unnecessary casts (#1728) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend the IR pass delivered in #1688 Pruning of cast expressions may appear as a `as_fieldop` expression with form `(⇑(λ(__val) → cast_(·__val, float64)))(a)`, where `a` is already a field with data type `float64` in this example. This PR adds pruning of such trivial expressions. --- src/gt4py/next/ffront/foast_to_gtir.py | 4 +-- .../ir_utils/common_pattern_matcher.py | 26 ++++++++++++++++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 22 +++++++++++++ .../next/iterator/transforms/prune_casts.py | 31 +++++++++++-------- .../ffront_tests/test_foast_to_gtir.py | 28 +++++------------ .../transforms_tests/test_prune_casts.py | 19 ++++++++++++ .../dace_tests/test_gtir_to_sdfg.py | 4 +-- 7 files changed, 94 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 6cf4cc67fd..2c2971f49a 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -360,9 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: if isinstance(t[0], ts.FieldType): - return im.as_fieldop( - im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) - )(expr) + return im.cast_as_fieldop(str(new_type))(expr) else: assert isinstance(t[0], ts.ScalarType) return im.call("cast_")(expr, str(new_type)) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 16a88b282a..9df091ac2a 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -10,6 +10,7 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -84,3 +85,28 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC def is_ref_to(node, ref: str): return isinstance(node, itir.SymRef) and node.id == ref + + +def is_identity_as_fieldop(node: itir.Expr): + """ + Match field operators implementing element-wise copy of a field argument, + that is expressions of the form `as_fieldop(stencil)(*args)` + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.as_fieldop(im.lambda_("__arg0")(im.deref("__arg0")))("a") + >>> is_identity_as_fieldop(node) + True + >>> node = im.as_fieldop("deref")("a") + >>> is_identity_as_fieldop(node) + False + """ + if not is_applied_as_fieldop(node): + return False + stencil = node.fun.args[0] # type: ignore[attr-defined] + if ( + isinstance(stencil, itir.Lambda) + and len(stencil.params) == 1 + and stencil == im.lambda_(stencil.params[0])(im.deref(stencil.params[0].id)) + ): + return True + return False diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index d7a66b8285..2864c7f727 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -497,6 +497,28 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): + """ + Promotes the function `cast_` to a field_operator. + + Args: + type_: the target type to be passed as argument to `cast_` function. + domain: the domain of the returned field. + + Returns: + A function from Fields to Field. + + Examples: + >>> str(cast_as_fieldop("float32")("a")) + '(⇑(λ(__arg0) → cast_(·__arg0, float32)))(a)' + """ + + def _impl(it: itir.Expr) -> itir.FunCall: + return op_as_fieldop(lambda v: call("cast_")(v, type_), domain)(it) + + return _impl + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index 0720394db5..c825f68a5f 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -6,13 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py import eve from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts -class PruneCasts(PreserveLocationVisitor, NodeTranslator): +class PruneCasts(eve.NodeTranslator): """ Removes cast expressions where the argument is already in the target type. @@ -20,23 +20,28 @@ class PruneCasts(PreserveLocationVisitor, NodeTranslator): therefore it should be applied after type-inference. """ + PRESERVED_ANNEX_ATTRS = ("domain",) + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: node = self.generic_visit(node) - if not cpm.is_call_to(node, "cast_"): - return node + if cpm.is_call_to(node, "cast_"): + value, type_constructor = node.args - value, type_constructor = node.args + assert ( + value.type + and isinstance(type_constructor, ir.SymRef) + and (type_constructor.id in ir.TYPEBUILTINS) + ) + dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) - assert ( - value.type - and isinstance(type_constructor, ir.SymRef) - and (type_constructor.id in ir.TYPEBUILTINS) - ) - dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) + if value.type == dtype: + return value - if value.type == dtype: - return value + elif cpm.is_identity_as_fieldop(node): + # pruning of cast expressions may leave some trivial `as_fieldop` expressions + # with form '(⇑(λ(__arg) → ·__arg))(a)' + return node.args[0] return node diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 4a1a7cba8e..516890ea46 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -284,9 +284,7 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - "a" - ) + reference = im.cast_as_fieldop("int32")("a") assert lowered.expr == reference @@ -312,12 +310,8 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") - ), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference @@ -332,9 +326,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), im.call("cast_")(im.tuple_get(1, "a"), "int32"), ) @@ -356,16 +348,10 @@ def foo( reference = im.make_tuple( im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, im.tuple_get(0, "a")) - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, im.tuple_get(0, "a")) - ), - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") + im.cast_as_fieldop("int32")(im.tuple_get(0, im.tuple_get(0, "a"))), + im.cast_as_fieldop("int32")(im.tuple_get(1, im.tuple_get(0, "a"))), ), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 462eed8408..7c991fb9a8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts @@ -21,3 +22,21 @@ def test_prune_casts_simple(): expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) assert actual == expected + + +def test_prune_casts_fieldop(): + IDim = gtx.Dimension("IDim") + x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) + y_ref = im.ref("y", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))) + testee = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + im.cast_as_fieldop("float64")(y_ref), + ) + testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + + expected = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + y_ref, + ) + actual = PruneCasts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index a94157ecd2..e0c0c3fa4e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -146,9 +146,7 @@ def test_gtir_cast(): body=[ gtir.SetAt( expr=im.op_as_fieldop("eq", domain)( - im.as_fieldop( - im.lambda_("a")(im.call("cast_")(im.deref("a"), "float32")), domain - )("x"), + im.cast_as_fieldop("float32", domain)("x"), "y", ), domain=domain, From 0a01597d0bcd5b1288ff9d42293fc1225738e977 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 21 Nov 2024 14:21:26 +0100 Subject: [PATCH 049/178] bug[next]: extract scalar value with correct dtype (#1723) credits to @egparedes for this pattern and realizing that `item()` decays to a python type. --- src/gt4py/next/embedded/nd_array_field.py | 3 ++- .../runners/dace_common/dace_backend.py | 6 +----- .../unit_tests/embedded_tests/test_nd_array_field.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 655a1137e8..9ff5feaaee 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -148,7 +148,8 @@ def as_scalar(self) -> core_defs.ScalarT: raise ValueError( f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." ) - return self.ndarray.item() + # note: `.item()` will return a Python type, therefore we use indexing with an empty tuple + return self.asnumpy()[()] # type: ignore[return-value] # should be ensured by the 0-d check @property def codomain(self) -> type[core_defs.ScalarT]: diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index bbf45a822c..db0df7d121 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -28,11 +28,7 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: return arg if len(arg.domain.dims) == 0: # Pass zero-dimensional fields as scalars. - # We need to extract the scalar value from the 0d numpy array without changing its type. - # Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar, - # which may change its precision. To avoid this, we use here the empty tuple as index - # for 'ndarray.__getitem__()'. - return arg.asnumpy()[()] + return arg.as_scalar() # field domain offsets are not supported non_zero_offsets = [ (dim, dim_range) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9fba633cba..063e79d92e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -264,6 +264,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte assert np.allclose(op_result.ndarray, expected_result) +def test_as_scalar(nd_array_implementation): + testee = common._field( + nd_array_implementation.asarray(42.0, dtype=np.float32), domain=common.Domain() + ) + + result = testee.as_scalar() + assert result == 42.0 + assert isinstance(result, np.float32) + + def product_nd_array_implementation_params(): for xp1 in nd_array_field._nd_array_implementations: for xp2 in nd_array_field._nd_array_implementations: From 1cb29e3ce7f24954a14054be51d375f9851d533c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 11:14:53 +0100 Subject: [PATCH 050/178] build: add devcontainer setup (#1725) Add devcontainer configuration with special customizations for VS Code. --------- Co-authored-by: Enrique Gonzalez Paredes --- .devcontainer/.vscode/launch.json | 24 +++++++++++++++ .devcontainer/Dockerfile | 5 ++++ .devcontainer/devcontainer.json | 49 +++++++++++++++++++++++++++++++ .devcontainer/setup.sh | 10 +++++++ .gitignore | 2 +- 5 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 .devcontainer/.vscode/launch.json create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100755 .devcontainer/setup.sh diff --git a/.devcontainer/.vscode/launch.json b/.devcontainer/.vscode/launch.json new file mode 100644 index 0000000000..f682b56388 --- /dev/null +++ b/.devcontainer/.vscode/launch.json @@ -0,0 +1,24 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File (just my code)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "Python: Current File (all)", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000..414f2d0292 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM mcr.microsoft.com/devcontainers/python:1-3.10-bookworm +RUN apt-get update \ + && export DEBIAN_FRONTEND=noninteractive && apt-get install -y libboost-dev \ + && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/bin" sh diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000..7dc4b2f08c --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,49 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "build": { + "dockerfile": "Dockerfile" + }, + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash .devcontainer/setup.sh", + + "containerEnv": { + "PRE_COMMIT_HOME": "/workspaces/gt4py/.caches/pre-commit" + }, + + // Configure tool-specific properties. + "customizations": { + // Configure properties specific to VS Code. + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "python.formatting.provider": "ruff", + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "/workspaces/gt4py/.venv/bin/python", + "files.insertFinalNewline": true, + "python.terminal.activateEnvironment": true, + "cmake.ignoreCMakeListsMissing": true + }, + "extensions": [ + "charliermarsh.ruff", + "donjayamanne.githistory", + "github.vscode-github-actions", + "lextudio.restructuredtext", + "ms-python.python", + "ms-vsliveshare.vsliveshare", + "swyddfa.esbonio" + ] + } + } + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh new file mode 100755 index 0000000000..d23dda9dea --- /dev/null +++ b/.devcontainer/setup.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +ln -sfn /workspaces/gt4py/.devcontainer/.vscode /workspaces/gt4py/.vscode +uv venv .venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +uv pip install -e . +uv pip install -i https://test.pypi.org/simple/ atlas4py +pre-commit install --install-hooks +deactivate diff --git a/.gitignore b/.gitignore index 5792b8a9b7..b1c8ed26e9 100644 --- a/.gitignore +++ b/.gitignore @@ -159,5 +159,5 @@ venv.bak/ ### Others ### .obsidian - coverage.json +.caches From d7f55522beacfc77c12964f6bbb1962899d8821d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 14:22:24 +0100 Subject: [PATCH 051/178] feat[next]: remove NeighborTableOffsetProvider, use gtx.as_connectivity (#1729) User-facing change: use `gtx.as_connectivity` to create a connectivity/neighbor table instead of `NeighborTableOffsetProvider` which is deprecated (and the backward-compatible mechanism broken for some use-cases). The internal concepts of `Connectivity` and `NeighborTable` are updated. `ConnectivityType` is introduced which contains the compile-time info of a `Connectivity`. See ADR 19. Additionally, the compile-time info is used (instead of the run-time connectivities) in many places of the toolchain when possible. --- .gitpod/.vscode/launch.json | 13 +- .../0008-Mapping_Domain_to_Cpp-Backend.md | 2 +- docs/development/ADRs/0019-Connectivities.md | 55 +++++ docs/user/next/QuickstartGuide.md | 6 +- .../exercises/2_divergence_exercise.ipynb | 4 +- .../2_divergence_exercise_solution.ipynb | 4 +- .../exercises/3_gradient_exercise.ipynb | 4 +- .../3_gradient_exercise_solution.ipynb | 4 +- .../workshop/exercises/4_curl_exercise.ipynb | 4 +- .../exercises/4_curl_exercise_solution.ipynb | 4 +- .../exercises/5_vector_laplace_exercise.ipynb | 10 +- .../5_vector_laplace_exercise_solution.ipynb | 10 +- .../8_diffusion_exercise_solution.ipynb | 8 +- docs/user/next/workshop/slides/slides_2.ipynb | 10 +- src/gt4py/_core/definitions.py | 10 +- src/gt4py/next/__init__.py | 6 +- src/gt4py/next/common.py | 170 ++++++++++---- src/gt4py/next/constructors.py | 24 +- src/gt4py/next/embedded/nd_array_field.py | 35 ++- src/gt4py/next/ffront/decorator.py | 47 ++-- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 30 +-- src/gt4py/next/iterator/embedded.py | 215 +++++++++++------- .../next/iterator/ir_utils/domain_utils.py | 26 +-- src/gt4py/next/iterator/runtime.py | 10 +- .../iterator/transforms/collapse_tuple.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../iterator/transforms/fuse_as_fieldop.py | 9 +- .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/inline_scalar.py | 4 +- .../next/iterator/transforms/pass_manager.py | 29 ++- .../transforms/pass_manager_legacy.py | 14 +- .../next/iterator/transforms/unroll_reduce.py | 28 +-- .../next/iterator/type_system/inference.py | 34 +-- .../iterator/type_system/type_synthesizer.py | 48 ++-- src/gt4py/next/otf/arguments.py | 54 +---- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 76 +------ .../codegens/gtfn/gtfn_module.py | 47 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 31 +-- .../runners/dace_common/dace_backend.py | 21 +- .../runners/dace_common/utility.py | 15 +- .../runners/dace_fieldview/gtir_dataflow.py | 75 +++--- .../runners/dace_fieldview/gtir_sdfg.py | 33 ++- .../runners/dace_fieldview/workflow.py | 6 +- .../runners/dace_iterator/__init__.py | 53 +++-- .../runners/dace_iterator/itir_to_sdfg.py | 45 ++-- .../runners/dace_iterator/itir_to_tasklet.py | 97 ++++---- .../runners/dace_iterator/utility.py | 10 +- .../runners/dace_iterator/workflow.py | 6 +- .../next/program_processors/runners/gtfn.py | 16 +- .../program_processors/runners/roundtrip.py | 16 +- .../next/type_system/type_specifications.py | 1 + .../feature_tests/dace/test_orchestration.py | 86 ++++--- .../ffront_tests/ffront_test_utils.py | 91 +++++--- .../ffront_tests/test_execution.py | 36 +-- .../ffront_tests/test_external_local_field.py | 12 +- .../ffront_tests/test_gt4py_builtins.py | 18 +- .../test_temporaries_with_sizes.py | 2 +- .../iterator_tests/test_builtins.py | 40 +--- .../test_strided_offset_provider.py | 9 +- .../ffront_tests/test_ffront_fvm_nabla.py | 64 +++--- .../multi_feature_tests/fvm_nabla_setup.py | 56 +++-- .../iterator_tests/test_fvm_nabla.py | 114 ++++------ .../test_with_toy_connectivity.py | 38 ++-- tests/next_tests/toy_connectivity.py | 7 + tests/next_tests/unit_tests/conftest.py | 25 +- .../embedded_tests/test_nd_array_field.py | 15 +- .../test_embedded_field_with_list.py | 10 +- .../iterator_tests/test_runtime_domain.py | 10 +- .../iterator_tests/test_type_inference.py | 34 +-- .../transforms_tests/test_cse.py | 14 +- .../transforms_tests/test_domain_inference.py | 13 +- .../transforms_tests/test_fuse_as_fieldop.py | 13 +- .../transforms_tests/test_global_tmps.py | 8 +- .../transforms_tests/test_prune_casts.py | 6 +- .../transforms_tests/test_unroll_reduce.py | 69 ++++-- .../gtfn_tests/test_itir_to_gtfn_ir.py | 4 +- .../runners_tests/dace_tests/test_dace.py | 24 +- .../dace_tests/test_gtir_to_sdfg.py | 134 ++++++----- .../unit_tests/test_constructors.py | 14 +- 80 files changed, 1293 insertions(+), 1170 deletions(-) create mode 100644 docs/development/ADRs/0019-Connectivities.md diff --git a/.gitpod/.vscode/launch.json b/.gitpod/.vscode/launch.json index f682b56388..b25a182648 100644 --- a/.gitpod/.vscode/launch.json +++ b/.gitpod/.vscode/launch.json @@ -6,7 +6,7 @@ "configurations": [ { "name": "Python: Current File (just my code)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", @@ -14,11 +14,20 @@ }, { "name": "Python: Current File (all)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", "justMyCode": false + }, + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": true } ] } diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index a1ee8575d2..1ce83431ee 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -20,7 +20,7 @@ The Python embedded execution for Iterator IR keeps track of the current locatio ### Python side -On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` (aka `NeighborTableOffsetProvider` in the current implementation) describes the mapping between location types. +On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` describes the mapping between location types. ### C++ side diff --git a/docs/development/ADRs/0019-Connectivities.md b/docs/development/ADRs/0019-Connectivities.md new file mode 100644 index 0000000000..76e85e49a6 --- /dev/null +++ b/docs/development/ADRs/0019-Connectivities.md @@ -0,0 +1,55 @@ +--- +tags: [] +--- + +# [Connectivities] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2024-11-08 +- **Updated**: 2024-11-08 + +The representation of Connectivities (neighbor tables, `NeighborTableOffsetProvider`) and their identifier (offset tag, `FieldOffset`, etc.) was extended and modified based on the needs of different parts of the toolchain. Here we outline the ideas for consolidating the different closely-related concepts. + +## History + +In the early days of Iterator IR (ITIR), an `offset` was a literal in the IR. Its meaning was only provided at execution time by a mapping from `offset` tag to an entity that we labelled `OffsetProvider`. We had mainly 2 kinds of `OffsetProvider`: a `Dimension` representing a Cartesian shift and a `NeighborTableOffsetProvider` for unstructured shifts. Since the type of `offset` needs to be known for compilation (strided for Cartesian, lookup-table for unstructured), this prevents a clean interface for ahead-of-time compilation. +For the frontend type-checking we later introduce a `FieldOffset` which contained type information of the mapped dimensions. +For (field-view) embedded we introduced a `ConnectivityField` (now `Connectivity`) which could be generated from the OffsetProvider information. + +These different concepts had overlap but were not 1-to-1 replacements. + +## Decision + +We update and introduce the following concepts + +### Conceptual definitions + +**Connectivity** is a mapping from index (or product of indices) to index. It covers 1-to-1 mappings, e.g. Cartesian shifts, NeighborTables (2D mappings) and dynamic Cartesian shifts. + +**NeighborConnectivity** is a 2D mapping of the N neighbors of a Location A to a Location B. + +**NeighborTable** is a _NeighborConnectivity_ backed by a buffer. + +**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation. + +### Full definitions + +See `next.common` module + +Note: Currently, the compiled backends supports only `NeighborConnectivity`s that are `NeighborTable`s. We do not yet encode this in the type and postpone discussion to the point where we support alternative implementations (e.g. `StridedNeighborConnectivity`). + +## Which parts of the toolchain use which concept? + +### Embedded + +Embedded execution of field-view supports any kind of `Connectivity`. +Embedded execution of iterator (local) view supports only `NeighborConnectivity`s. + +### IR transformations and compiled backends + +All transformations and code-generation should use `ConnectivityType`, not the `Connectivity` which contains the runtime mapping. + +Note, currently the `global_tmps` pass uses runtime information, therefore this is not strictly enforced. + +The only supported `Connectivity`s in compiled backends (currently) are `NeighborTable`s. diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 81604c7770..2cb6647519 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -155,8 +155,6 @@ This section approaches the pseudo-laplacian by introducing the required APIs pr - [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) - [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) -+++ - #### Defining the mesh and its connectivities The examples related to unstructured meshes use the mesh below. The edges (in blue) and the cells (in red) are numbered with zero-based indices. @@ -237,7 +235,7 @@ E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) Note that the field offset does not contain the actual connectivity table, that's provided through an _offset provider_: ```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2) +E2C_offset_provider = gtx.as_connectivity([EdgeDim, E2CDim], codomain=CellDim, data=edge_to_cell_table, skip_value=-1) ``` The field operator `nearest_cell_to_edge` below shows an example of applying this transform. There is a little twist though: the subscript in `E2C[0]` means that only the value of the first connected cell is taken, the second (if exists) is ignored. @@ -385,7 +383,7 @@ As explained in the section outline, the pseudo-laplacian needs the cell-to-edge C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) C2E = gtx.FieldOffset("C2E", source=EdgeDim, target=(CellDim, C2EDim)) -C2E_offset_provider = gtx.NeighborTableOffsetProvider(cell_to_edge_table, CellDim, EdgeDim, 3) +C2E_offset_provider = gtx.as_connectivity([CellDim, C2EDim], codomain=EdgeDim, data=cell_to_edge_table, skip_value=-1) ``` **Weights of edge differences:** diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb index 50349e52b0..b0a1980d0f 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -113,7 +113,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb index 6baac2b8c0..573ee6a44e 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -118,7 +118,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index c8914120d3..2b422b1823 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -110,7 +110,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb index 5e940a4b71..85044b989f 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index 4a6b37baf7..dc321f1bdd 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -134,7 +134,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb index 065cf02de7..251fe8239a 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -139,7 +139,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb index 832375a86b..30f568de6f 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -272,10 +272,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb index be846d199d..eaeb8c7b02 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -293,10 +293,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb index d4bcdb33d5..b278cee26d 100644 --- a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -156,10 +156,8 @@ " dt,\n", " )\n", "\n", - " e2c2v_connectivity = gtx.NeighborTableOffsetProvider(\n", - " e2c2v_table, E, V, 4, has_skip_values=False\n", - " )\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " e2c2v_connectivity = gtx.as_connectivity([E, E2C2VDim], codomain=V, data=e2c2v_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " diffusion_step(\n", " u,\n", diff --git a/docs/user/next/workshop/slides/slides_2.ipynb b/docs/user/next/workshop/slides/slides_2.ipynb index 1e8925087f..c6967df4b2 100644 --- a/docs/user/next/workshop/slides/slides_2.ipynb +++ b/docs/user/next/workshop/slides/slides_2.ipynb @@ -281,17 +281,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6d30a5e1", "metadata": {}, "outputs": [], "source": [ - "E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, Edge, Cell, 2)" + "E2C_offset_provider = gtx.as_connectivity(\n", + " [Edge, E2CDim], codomain=Cell, data=e2c_table, skip_value=-1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d62f6c98", "metadata": {}, "outputs": [ @@ -311,7 +313,7 @@ " return cell_field(E2C[0]) # 0th index to isolate edge dimension\n", "\n", "\n", - "@gtx.program # uses skip_values, therefore we cannot use embedded\n", + "@gtx.program\n", "def run_nearest_cell_to_edge(\n", " cell_field: gtx.Field[Dims[Cell], float64], edge_field: gtx.Field[Dims[Edge], float64]\n", "):\n", diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..8f62788b8f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -439,13 +439,21 @@ def ndim(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... + @property + def strides(self) -> tuple[int, ...]: ... + @property def dtype(self) -> Any: ... + @property + def itemsize(self) -> int: ... + def item(self) -> Any: ... def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def any(self) -> bool: ... + def __getitem__(self, item: Any) -> NDArrayObject: ... def __abs__(self) -> NDArrayObject: ... @@ -496,4 +504,4 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 80bb276c70..4fa9215706 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -20,6 +20,7 @@ from . import common, ffront, iterator, program_processors from .common import ( + Connectivity, Dimension, DimensionKind, Dims, @@ -39,8 +40,7 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - NeighborTableOffsetProvider, - StridedNeighborOffsetProvider, + NeighborTableOffsetProvider, # TODO(havogt): deprecated index_field, np_as_located_field, ) @@ -61,6 +61,7 @@ "Dimension", "DimensionKind", "Field", + "Connectivity", "GridType", "domain", "Domain", @@ -75,7 +76,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", # from ffront diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4aa0dd03aa..9b2870e1c0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import utils @@ -95,7 +94,7 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> ConnectivityField: + def __add__(self, offset: int) -> Connectivity: # TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this. from gt4py.next.ffront import fbuiltins @@ -104,7 +103,7 @@ def __add__(self, offset: int) -> ConnectivityField: dimension_to_implicit_offset(self.value), source=self, target=(self,) )[offset] - def __sub__(self, offset: int) -> ConnectivityField: + def __sub__(self, offset: int) -> Connectivity: return self + (-offset) @@ -678,6 +677,9 @@ def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... + # TODO(havogt) + # This property is wrong, because for a function field we would not know to which NDArrayObject we want to convert + # at the very least, we need to take an allocator and rename this to `as_ndarray`. @property def ndarray(self) -> core_defs.NDArrayObject: ... @@ -688,7 +690,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: Connectivity | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -700,8 +702,8 @@ def as_scalar(self) -> core_defs.ScalarT: ... @abc.abstractmethod def __call__( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, ) -> Field: ... @abc.abstractmethod @@ -811,12 +813,64 @@ def remapping(cls) -> ConnectivityKind: return cls.ALTER_DIMS | cls.ALTER_STRUCT +@dataclasses.dataclass(frozen=True) +class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import + domain: tuple[Dimension, ...] + codomain: Dimension + skip_value: Optional[core_defs.IntegralScalar] + dtype: core_defs.DType + + @property + def has_skip_values(self) -> bool: + return self.skip_value is not None + + +@dataclasses.dataclass(frozen=True) +class NeighborConnectivityType(ConnectivityType): + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int + + @property + def source_dim(self) -> Dimension: + return self.domain[0] + + @property + def neighbor_dim(self) -> Dimension: + return self.domain[1] + + @runtime_checkable # type: ignore[misc] # DimT should be covariant, but then it breaks in other places -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + """ + The `codomain` is the set of all indices in a certain `Dimension`. + + We use the `Dimension` itself to describe the (infinite) set of all indices. + + Note: + We could restrict the infinite codomain to only the indices that are actually contained in the mapping. + Currently, this would just complicate implementation as we do not use this information. + """ + + def __gt_type__(self) -> ConnectivityType: + if is_neighbor_connectivity(self): + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) + else: + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + ) @property def kind(self) -> ConnectivityKind: @@ -831,61 +885,61 @@ def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... # Operators def __abs__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") # Utility function to construct a `Field` from different buffer representations. @@ -911,38 +965,58 @@ def _connectivity( domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> ConnectivityField: +) -> Connectivity: raise NotImplementedError -@runtime_checkable -class Connectivity(Protocol): - max_neighbors: int - has_skip_values: bool - origin_axis: Dimension - neighbor_axis: Dimension - index_type: type[int] | type[np.int32] | type[np.int64] +class NeighborConnectivity(Connectivity, Protocol): + # TODO(havogt): work towards encoding this properly in the type + def __gt_type__(self) -> NeighborConnectivityType: ... + - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - """Return neighbor index.""" +def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + if not isinstance(obj, Connectivity): + return False + domain_dims = obj.domain.dims + return ( + len(domain_dims) == 2 + and domain_dims[0].kind is DimensionKind.HORIZONTAL + and domain_dims[1].kind is DimensionKind.LOCAL + ) -@runtime_checkable -class NeighborTable(Connectivity, Protocol): - table: npt.NDArray +class NeighborTable( + NeighborConnectivity, Protocol +): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) + @property + def ndarray(self) -> core_defs.NDArrayObject: + # Note that this property is currently already there from inheriting from `Field`, + # however this seems wrong, therefore we explicitly introduce it here (or it should come + # implicitly from the `NdArrayConnectivityField` protocol). + ... -OffsetProviderElem: TypeAlias = Dimension | Connectivity +def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: + return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") + + +OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +OffsetProviderType: TypeAlias = Mapping[Tag, OffsetProviderTypeElem] + + +def offset_provider_to_type(offset_provider: OffsetProvider) -> OffsetProviderType: + return { + k: v.__gt_type__() if isinstance(v, Connectivity) else v for k, v in offset_provider.items() + } DomainDimT = TypeVar("DomainDimT", bound="Dimension") @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): domain_dim: DomainDimT codomain: DimT offset: int = 0 @@ -981,7 +1055,7 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `ConnectivityField` Protocol. + # abstract property of the `Connectivity` Protocol. if not TYPE_CHECKING: @functools.cached_property @@ -1024,9 +1098,9 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa def premap( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, - ) -> ConnectivityField: + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: raise NotImplementedError() __call__ = premap diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..7b39511674 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -290,22 +290,24 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - skip_value: Optional[core_defs.IntegralScalar] = None, + skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False -) -> common.ConnectivityField: +) -> common.Connectivity: """ - Construct a connectivity field from the given domain, codomain, and data. + Construct a `Connectivity` from the given domain, codomain, and data. Arguments: - domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + domain: The domain of the connectivity. It can be either a `common.DomainLike` object or a sequence of `common.Dimension` objects. - codomain: The codomain dimension of the connectivity field. + codomain: The codomain dimension of the connectivity. data: The data used to construct the connectivity field. - dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. - allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + dtype: The data type of the connectivity. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity. If not provided, a default allocator will be used. - device: The device on which the connectivity field will be allocated. If not provided, the default + device: The device on which the connectivity will be allocated. If not provided, the default device will be used. + skip_value: The value that signals missing entries in the neighbor table. Defaults to the default + skip value if it is found in data, otherwise to `None` (= no skip value). Returns: The constructed connectivity field. @@ -313,9 +315,15 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ + if skip_value is eve.NOTHING: + skip_value = ( + common._DEFAULT_SKIP_VALUE if (data == common._DEFAULT_SKIP_VALUE).any() else None + ) + assert ( skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable + skip_value = cast(Optional[core_defs.IntegralScalar], skip_value) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9ff5feaaee..e15fb4266a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -36,7 +36,6 @@ exceptions as embedded_exceptions, ) from gt4py.next.ffront import experimental, fbuiltins -from gt4py.next.iterator import embedded as itir_embedded try: @@ -189,10 +188,10 @@ def from_array( def premap( self: NdArrayField, - *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, + *connectivities: common.Connectivity | fbuiltins.FieldOffset, ) -> NdArrayField: """ - Rearrange the field content using the provided connectivity fields as index mappings. + Rearrange the field content using the provided connectivities (index mappings). This operation is conceptually equivalent to a regular composition of mappings `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. @@ -206,7 +205,7 @@ def premap( argument used in the right hand side of the operator should therefore have the same product of dimensions `c: S × T → A × B`. Such a mapping can also be expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this - is actually the only supported form in GT4Py because `ConnectivityField` instances + is actually the only supported form in GT4Py because `Connectivity` instances can only deal with a single dimension in its codomain. This approach makes connectivities reusable for any combination of dimensions in a field domain and matches the NumPy advanced indexing API, which basically is a @@ -261,15 +260,15 @@ def premap( """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists - conn_fields: list[common.ConnectivityField] = [] + conn_fields: list[common.Connectivity] = [] codomains_counter: collections.Counter[common.Dimension] = collections.Counter() for connectivity in connectivities: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): + # For neighbor reductions, a FieldOffset is passed instead of an actual Connectivity + if not isinstance(connectivity, common.Connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + assert isinstance(connectivity, common.Connectivity) # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, @@ -318,8 +317,8 @@ def premap( def __call__( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: return functools.reduce( lambda field, current_index_field: field.premap(current_index_field), @@ -460,7 +459,7 @@ def _dace_descriptor(self) -> Any: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ - common.ConnectivityField[common.DimsT, common.DimT], + common.Connectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -579,7 +578,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: +def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> NdArrayField: """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" new_domain = data.domain for connectivity in connectivities: @@ -668,7 +667,7 @@ def _reshuffling_premap( ) -def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: +def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> NdArrayField: new_dims = {*connectivity.domain.dims} - {connectivity.codomain} if repeated_dims := (new_dims & {*data.domain.dims}): raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") @@ -693,7 +692,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField if restricted_connectivity_domain != connectivity.domain else connectivity ) - assert isinstance(restricted_connectivity, common.ConnectivityField) + assert isinstance(restricted_connectivity, common.Connectivity) # 2- then compute the index array new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start @@ -971,7 +970,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( @@ -996,15 +995,15 @@ def _builtin_op( offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) + assert common.is_neighbor_table(offset_definition) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis + slice(None) if d in [axis, offset_definition.domain.dims[0]] else xp.newaxis for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, + xp.asarray(offset_definition.ndarray[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dc2421e1d2..9ce07d01bb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -30,7 +30,6 @@ embedded as next_embedded, errors, ) -from gt4py.next.common import Connectivity, Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, @@ -82,15 +81,15 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[dict[str, Connectivity]] + connectivities: Optional[common.OffsetProviderType] = None @classmethod def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend], - grid_type: Optional[GridType] = None, - connectivities: Optional[dict[str, Connectivity]] = None, + grid_type: Optional[common.GridType] = None, + connectivities: Optional[common.OffsetProviderType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) return cls(definition_stage=program_def, backend=backend, connectivities=connectivities) @@ -140,10 +139,10 @@ def _frontend_transforms(self) -> next_backend.Transforms: def with_backend(self, backend: next_backend.Backend) -> Program: return dataclasses.replace(self, backend=backend) - def with_connectivities(self, connectivities: dict[str, Connectivity]) -> Program: + def with_connectivities(self, connectivities: common.OffsetProviderType) -> Program: return dataclasses.replace(self, connectivities=connectivities) - def with_grid_type(self, grid_type: GridType) -> Program: + def with_grid_type(self, grid_type: common.GridType) -> Program: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -199,7 +198,7 @@ def itir(self) -> itir.FencilDefinition: return self._frontend_transforms.past_to_itir(no_args_past).data @functools.cached_property - def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: + def _implicit_offset_provider(self) -> dict[str, common.Dimension]: """ Add all implicit offset providers. @@ -226,9 +225,7 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( @@ -287,19 +284,17 @@ def definition(self) -> str: def with_backend(self, backend: next_backend.Backend) -> FrozenProgram: return self.__class__(program=self.program, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram: return self.__class__( program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) def jit( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any ) -> stages.CompiledProgram: return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) if not self._compiled_program: @@ -328,7 +323,7 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" @@ -350,7 +345,7 @@ def __post_init__(self): class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args, offset_provider: common.OffsetProvider, **kwargs): type_ = self.past_stage.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -436,7 +431,7 @@ def program( *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -506,7 +501,7 @@ def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend.Backend], - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, @@ -557,7 +552,7 @@ def __gt_type__(self) -> ts.CallableType: def with_backend(self, backend: next_backend.Backend) -> FieldOperator: return dataclasses.replace(self, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FieldOperator: + def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -688,33 +683,33 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast. def scan_operator( definition: types.FunctionType, *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> FieldOperator[foast.ScanOperator]: ... @typing.overload def scan_operator( *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... def scan_operator( definition: Optional[types.FunctionType] = None, *, - axis: Dimension, + axis: common.Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, backend=eve.NOTHING, - grid_type: GridType = None, + grid_type: common.GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 8a94c20832..bd22aebe57 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -14,7 +14,7 @@ @BuiltInFunction -def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d932431b51..b60fa63f95 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -16,7 +16,6 @@ import numpy as np from numpy import float32, float64, int32, int64 -import gt4py.next as gtx from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS @@ -55,7 +54,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.DimensionType elif t is FieldOffset: return ts.OffsetType - elif t is common.ConnectivityField: + elif t is common.Connectivity: return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType @@ -321,7 +320,7 @@ def __post_init__(self) -> None: def __gt_type__(self) -> ts.OffsetType: return ts.OffsetType(source=self.source, target=self.target) - def __getitem__(self, offset: int) -> common.ConnectivityField: + def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" from gt4py.next import embedded # avoid circular import @@ -330,22 +329,19 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: assert current_offset_provider is not None offset_definition = current_offset_provider[self.value] - connectivity: common.ConnectivityField + connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance( - offset_definition, (gtx.NeighborTableOffsetProvider, common.ConnectivityField) - ): - unrestricted_connectivity = self.as_connectivity_field() - assert unrestricted_connectivity.domain.ndim > 1 + elif isinstance(offset_definition, common.Connectivity): + assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) - connectivity = unrestricted_connectivity[named_index] + connectivity = offset_definition[named_index] else: raise NotImplementedError() return connectivity - def as_connectivity_field(self) -> common.ConnectivityField: + def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" from gt4py.next import embedded # avoid circular import @@ -356,18 +352,8 @@ def as_connectivity_field(self) -> common.ConnectivityField: cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: - if isinstance(offset_definition, common.ConnectivityField): + if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition - elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - connectivity = gtx.as_connectivity( - domain=self.target, - codomain=self.source, - data=offset_definition.table, - dtype=offset_definition.index_type, - skip_value=( - common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None - ), - ) else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6221c95522..3c63ffef30 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -93,77 +93,113 @@ class SparseTag(Tag): ... -class NeighborTableOffsetProvider: +@xtyping.deprecated("Use a 'Connectivity' instead.") +def NeighborTableOffsetProvider( + table: core_defs.NDArrayObject, + origin_axis: common.Dimension, + neighbor_axis: common.Dimension, + max_neighbors: int, + has_skip_values=True, +) -> common.Connectivity: + return common._connectivity( + table, + codomain=neighbor_axis, + domain={ + origin_axis: table.shape[0], + common.Dimension( + value="_DummyLocalDim", kind=common.DimensionKind.LOCAL + ): max_neighbors, + }, + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + ) + + +# TODO(havogt): complete implementation and make available for fieldview embedded +@dataclasses.dataclass(frozen=True) +class StridedConnectivityField(common.Connectivity): + domain_dims: tuple[common.Dimension, common.Dimension] + codomain_dim: common.Dimension + _max_neighbors: int + def __init__( self, - table: core_defs.NDArrayObject, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, + domain_dims: Sequence[common.Dimension], + codomain_dim: common.Dimension, max_neighbors: int, - has_skip_values=True, - ) -> None: - self.table = table - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - assert not hasattr(table, "shape") or table.shape[1] == max_neighbors - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = table.dtype - - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - res = self.table[(primary, neighbor_idx)] - assert common.is_int_index(res) - return res + ): + object.__setattr__(self, "domain_dims", tuple(domain_dims)) + object.__setattr__(self, "codomain_dim", codomain_dim) + object.__setattr__(self, "_max_neighbors", max_neighbors) - if dace: - # Extension of NeighborTableOffsetProvider adding SDFGConvertible support in GT4Py Programs - def _dace_data_ptr(self) -> int: - obj = self.table - if dace.dtypes.is_array(obj): - if hasattr(obj, "__array_interface__"): - return obj.__array_interface__["data"][0] - if hasattr(obj, "__cuda_array_interface__"): - return obj.__cuda_array_interface__["data"][0] - raise ValueError("Unsupported data container.") - - def _dace_descriptor(self) -> dace.data.Data: - return dace.data.create_datadescriptor(self.table) - else: + @property + def __gt_origin__(self) -> xtyping.Never: + raise NotImplementedError + + def __gt_type__(self) -> common.NeighborConnectivityType: + return common.NeighborConnectivityType( + domain=self.domain_dims, + codomain=self.codomain_dim, + max_neighbors=self._max_neighbors, + skip_value=self.skip_value, + dtype=self.dtype, + ) - def _dace_data_ptr(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "data_ptr is only supported when the 'dace' module is available." - ) + @property + def domain(self) -> common.Domain: + return common.Domain( + dims=self.domain_dims, + ranges=(common.UnitRange.infinite(), common.unit_range(self._max_neighbors)), + ) - def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "__descriptor__ is only supported when the 'dace' module is available." - ) + @property + def codomain(self) -> common.Dimension: + return self.codomain_dim - data_ptr = _dace_data_ptr - __descriptor__ = _dace_descriptor + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + @property + def ndarray(self) -> core_defs.NDArrayObject: + raise NotImplementedError -class StridedNeighborOffsetProvider: - def __init__( + def asnumpy(self) -> np.ndarray: + raise NotImplementedError + + def premap(self, index_field: common.Connectivity | fbuiltins.FieldOffset) -> common.Field: + raise NotImplementedError + + def restrict( # type: ignore[override] self, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, - max_neighbors: int, - has_skip_values=True, - ) -> None: - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = int + item: common.AnyIndexSpec, + ) -> common.Field: + if not isinstance(item, tuple) or (isinstance(item, tuple) and not len(item) == 2): + raise NotImplementedError() # TODO(havogt): add proper slicing + index = item[0] * self._max_neighbors + item[1] # type: ignore[operator, call-overload] + return ConstantField(index) + + def as_scalar(self) -> xtyping.Never: + raise NotImplementedError() + + def __call__( + self, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, + ) -> common.Field: + raise NotImplementedError() - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - return primary * self.max_neighbors + neighbor_idx + __getitem__ = restrict # type: ignore[assignment] + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + raise NotImplementedError + + @property + def skip_value( + self, + ) -> None: + return None # Offsets @@ -597,10 +633,11 @@ def execute_shift( new_entry[i] = 0 else: offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_neighbor_connectivity(offset_implementation) + source_dim = offset_implementation.__gt_type__().source_dim + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: @@ -620,22 +657,22 @@ def execute_shift( else: raise AssertionError() return new_pos - else: - assert isinstance(offset_implementation, common.Connectivity) - assert offset_implementation.origin_axis.value in pos + elif common.is_neighbor_connectivity(offset_implementation): + source_dim = offset_implementation.__gt_type__().source_dim + assert source_dim.value in pos new_pos = pos.copy() - new_pos.pop(offset_implementation.origin_axis.value) - cur_index = pos[offset_implementation.origin_axis.value] + new_pos.pop(source_dim.value) + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: return None else: - new_index = offset_implementation.mapped_index(cur_index, index) + new_index = offset_implementation[cur_index, index].as_scalar() assert new_index is not None - new_pos[offset_implementation.neighbor_axis.value] = int(new_index) + new_pos[offset_implementation.codomain.value] = int(new_index) return new_pos @@ -1196,8 +1233,8 @@ def as_scalar(self) -> core_defs.IntegralScalar: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1322,8 +1359,8 @@ def asnumpy(self) -> np.ndarray: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1428,10 +1465,12 @@ def __gt_type__(self) -> itir_ts.ListType: assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( - element_type=element_type, - offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), - ) + offset_provider = embedded_context.offset_provider.get() + assert offset_provider is not None + connectivity = offset_provider[offset_tag] + assert common.is_neighbor_connectivity(connectivity) + local_dim = connectivity.__gt_type__().neighbor_dim + return itir_ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1457,11 +1496,11 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if (shifted := it.shift(offset_str, i)).can_deref() ), offset=offset, @@ -1533,11 +1572,11 @@ def deref(self) -> Any: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if ( shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) ).can_deref() @@ -1671,9 +1710,9 @@ def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} -def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: +def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: if isinstance(domain, runtime.CartesianDomain): - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()): raise RuntimeError( "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." ) @@ -1770,10 +1809,10 @@ def _fieldspec_list_to_value( offset_type = type_.offset_type assert isinstance(offset_type, common.Dimension) connectivity = offset_provider[offset_type.value] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return domain.insert( len(domain), - common.named_range((offset_type, connectivity.max_neighbors)), + common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), ), type_.element_type return domain, type_ @@ -1809,7 +1848,7 @@ def closure( ) -> None: assert embedded_context.within_valid_context() offset_provider = embedded_context.offset_provider.get() - _validate_domain(domain_, offset_provider) + _validate_domain(domain_, common.offset_provider_to_type(offset_provider)) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8f842e1c13..f5625b509c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -12,7 +12,6 @@ import functools from typing import Any, Literal, Mapping, Optional -import gt4py.next as gtx from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -23,20 +22,19 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di """ Extract horizontal domain sizes from an `offset_provider`. - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. """ sizes = dict[str, int]() for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes @@ -114,7 +112,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(nbt_provider, common.Connectivity): + elif common.is_neighbor_connectivity(nbt_provider): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -132,8 +130,8 @@ def translate( for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis + old_dim = nbt_provider.__gt_type__().source_dim + new_dim = nbt_provider.__gt_type__().codomain assert new_dim not in new_ranges or old_dim == new_dim diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ad85d154cb..d42f961202 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -12,7 +12,7 @@ import functools import types from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import devtools @@ -127,7 +127,9 @@ def fendef( ) -def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[str, Any]): +def _deduce_domain( + domain: dict[common.Dimension, range], offset_provider_type: common.OffsetProviderType +): if isinstance(domain, UnstructuredDomain): domain_builtin = builtins.unstructured_domain elif isinstance(domain, CartesianDomain): @@ -135,7 +137,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ else: domain_builtin = ( builtins.unstructured_domain - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()) + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()) else builtins.cartesian_domain ) @@ -160,7 +162,7 @@ def impl(out, *inps): elif isinstance(dom, dict): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None - dom = _deduce_domain(dom, offset_provider) + dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) closure(dom, self.fundef_dispatcher, out, [*inps]) return impl diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f84714e779..e71a24127f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -105,7 +105,7 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider: Optional[common.OffsetProvider] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes flags: Optional[Flag] = None, @@ -126,7 +126,7 @@ def apply( `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags - offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} if isinstance(node, (ir.Program, ir.FencilDefinition)): within_stencil = False @@ -138,7 +138,7 @@ def apply( if not ignore_tuple_size: node = itir_type_inference.infer( node, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 38ea1fd53d..824adfdd8d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -411,7 +411,7 @@ def apply( cls, node: ProgramOrExpr, within_stencil: bool | None = None, - offset_provider: common.OffsetProvider | None = None, + offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: @@ -422,9 +422,9 @@ def apply( within_stencil is not None ), "The expression's context must be specified using `within_stencil`." - offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program + node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) return cls().visit(node, within_stencil=within_stencil) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index da238733da..9076bf2d3f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( @@ -89,7 +90,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ) >>> print( ... FuseAsFieldOp.apply( - ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... ) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) @@ -134,12 +135,14 @@ def apply( cls, node: itir.Program, *, - offset_provider, + offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, ): node = type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, ) if not uids: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 90f8a6cded..a6d39883e3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -187,7 +187,9 @@ def create_global_tmps( arguments into temporaries. """ program = infer_domain.infer_program(program, offset_provider) - program = type_inference.infer(program, offset_provider=offset_provider) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index c6e2c38b90..87b576d14d 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -17,8 +17,8 @@ class InlineScalar(eve.NodeTranslator): @classmethod - def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): - program = itir_inference.infer(program, offset_provider=offset_provider) + def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): + program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) def visit_Expr(self, node: itir.Expr): diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52a452155a..ec6f89685a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -43,8 +43,8 @@ def __call__( def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, + offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, - offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, @@ -56,7 +56,12 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram.apply(ir) @@ -75,7 +80,7 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -89,15 +94,15 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program - inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns # a list. Such expressions must be inlined however because no backend supports such # field operators right now. inlined = fuse_as_fieldop.FuseAsFieldOp.apply( - inlined, uids=mergeasfop_uids, offset_provider=offset_provider + inlined, uids=mergeasfop_uids, offset_provider_type=offset_provider_type ) if inlined == ir: @@ -108,19 +113,21 @@ def apply_common_transforms( # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) ir = MergeLet().visit(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True) if extract_temporaries: - ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -129,7 +136,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled # type: ignore[assignment] # still a `itir.Program` @@ -156,6 +163,8 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = CollapseTuple.apply( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 792bb421f1..94c962e92d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -10,6 +10,7 @@ from typing import Callable, Optional from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -75,8 +76,13 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: assert isinstance(ir, itir.FencilDefinition) + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + ir = fencil_to_program.FencilToProgram().apply(ir) icdlv_uids = eve_utils.UIDGenerator() @@ -109,7 +115,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -134,7 +140,7 @@ def apply_common_transforms( ir = CollapseTuple.apply( ir, ignore_tuple_size=True, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -149,7 +155,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled @@ -164,7 +170,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program ir = MergeLet().visit(ir) ir = InlineLambdas.apply( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index ec9c3efb2b..042a86cd8e 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -64,16 +64,16 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: + offset_provider_type: common.OffsetProviderType, +) -> common.NeighborConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.Connectivity] = [] + connectivities: list[common.NeighborConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) + conn = offset_provider_type[o] + assert isinstance(conn, common.NeighborConnectivityType) connectivities.append(conn) if not connectivities: @@ -120,15 +120,15 @@ class UnrollReduce(PreserveLocationVisitor, NodeTranslator): uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @classmethod - def apply(cls, node: itir.Node, **kwargs) -> itir.Node: - return cls().visit(node, **kwargs) - - def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - connectivity = _get_connectivity(node, offset_provider) - max_neighbors = connectivity.max_neighbors - has_skip_values = connectivity.has_skip_values + def apply(cls, node: itir.Node, offset_provider_type: common.OffsetProviderType) -> itir.Node: + return cls().visit(node, offset_provider_type=offset_provider_type) + + def _visit_reduce( + self, node: itir.FunCall, offset_provider_type: common.OffsetProviderType + ) -> itir.Expr: + connectivity_type = _get_connectivity(node, offset_provider_type) + max_neighbors = connectivity_type.max_neighbors + has_skip_values = connectivity_type.has_skip_values acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 66d8345b94..987eb0f308 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -155,7 +155,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( ... type_synthesizer=lambda base: power(base, int_type) ... ) - >>> square_func_type_synthesizer(float_type, offset_provider={}) + >>> square_func_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such @@ -169,7 +169,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_synthesizer(float_type, offset_provider={}) + >>> o_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type @@ -225,13 +225,15 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" - return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) + return_type_or_synthesizer = self.type_synthesizer( + *args, offset_provider_type=offset_provider_type + ) # return type is a typing rule by itself if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): @@ -250,18 +252,18 @@ def __call__( def _get_dimensions_from_offset_provider( - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in offset_provider.items(): + for offset_name, provider in offset_provider_type.items(): dimensions[offset_name] = common.Dimension( value=offset_name, kind=common.DimensionKind.LOCAL ) if isinstance(provider, common.Dimension): dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + elif isinstance(provider, common.NeighborConnectivityType): + dimensions[provider.source_dim.value] = provider.source_dim + dimensions[provider.codomain.value] = provider.codomain return dimensions @@ -318,7 +320,7 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider: common.OffsetProvider + offset_provider_type: common.OffsetProviderType #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. @@ -329,7 +331,7 @@ def apply( cls, node: T, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: @@ -340,7 +342,7 @@ def apply( node: The :class:`itir.Node` to infer the types of. Keyword Arguments: - offset_provider: Offset provider dictionary. + offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. @@ -403,9 +405,9 @@ def apply( ) instance = cls( - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, dimensions=( - _get_dimensions_from_offset_provider(offset_provider) + _get_dimensions_from_offset_provider(offset_provider_type) | _get_dimensions_from_types( node.pre_walk_values() .if_isinstance(itir.Node) @@ -540,7 +542,7 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for input_ in inputs ] stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider=self.offset_provider + *stencil_args, offset_provider_type=self.offset_provider_type ) return it_ts.StencilClosureType( @@ -632,7 +634,7 @@ def visit_FunCall( fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args, offset_provider=self.offset_provider) + result = fun(*args, offset_provider_type=self.offset_provider_type) if isinstance(result, ObservableTypeSynthesizer): assert not result.node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 43c4465576..5be9ed7438 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -35,20 +35,20 @@ class TypeSynthesizer: - isinstance checks to determine if an object is actually (meant to be) a type synthesizer and not just any callable. - writing simple type synthesizers without cluttering the signature with the additional - offset_provider argument that is only needed by some. + offset_provider_type argument that is only needed by some. """ type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): - if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + if "offset_provider_type" not in inspect.signature(self.type_synthesizer).parameters: synthesizer = self.type_synthesizer - self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) + self.type_synthesizer = lambda *args, offset_provider_type: synthesizer(*args) def __call__( - self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + self, *args: TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType ) -> TypeOrTypeSynthesizer: - return self.type_synthesizer(*args, offset_provider=offset_provider) + return self.type_synthesizer(*args, offset_provider_type=offset_provider_type) TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] @@ -212,7 +212,7 @@ def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) - def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def apply_lift( - *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType ) -> it_ts.IteratorType: assert all(isinstance(it, it_ts.IteratorType) for it in its) stencil_args = [ @@ -224,7 +224,7 @@ def apply_lift( ) for it in its ] - stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) + stencil_return_type = stencil(*stencil_args, offset_provider_type=offset_provider_type) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -282,7 +282,7 @@ def as_fieldop( stencil: TypeSynthesizer, domain: Optional[it_ts.DomainType] = None, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result # type. In order to still allow some passes which don't need this information to run before the @@ -308,7 +308,7 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( @@ -328,8 +328,10 @@ def scan( assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL @TypeSynthesizer - def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: - result = scan_pass(init, *its, offset_provider=offset_provider) + def apply_scan( + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType + ) -> ts.DataType: + result = scan_pass(init, *its, offset_provider_type=offset_provider_type) assert isinstance(result, ts.DataType) return result @@ -340,12 +342,12 @@ def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider: common.OffsetProvider + *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType ) -> it_ts.ListType: assert len(args) > 0 assert all(isinstance(arg, it_ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] - el_type = op(*arg_el_types, offset_provider=offset_provider) + el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) return it_ts.ListType(element_type=el_type) @@ -355,15 +357,17 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): + def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) + return op( + init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type + ) return applied_reduce @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider_type: common.OffsetProviderType) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, @@ -379,19 +383,19 @@ def apply_shift( assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( offset_axis.value, common.Dimension ) - provider = offset_provider[offset_axis.value.value] # TODO: naming - if isinstance(provider, common.Dimension): + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): pass - elif isinstance(provider, common.Connectivity): + elif isinstance(type_, common.NeighborConnectivityType): found = False for i, dim in enumerate(new_position_dims): - if dim.value == provider.origin_axis.value: + if dim.value == type_.source_dim.value: assert not found - new_position_dims[i] = provider.neighbor_axis + new_position_dims[i] = type_.codomain found = True assert found else: - raise NotImplementedError() + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") return it_ts.IteratorType( position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 802ad2155f..69d8985beb 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -26,7 +26,6 @@ import typing from typing import Any, Iterable, Iterator, Optional -import numpy as np from typing_extensions import Self from gt4py.next import common @@ -49,47 +48,19 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity(common.Connectivity): - """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" - - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self) -> None: - return None - - @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] kwargs: dict[str, ts.TypeSpec] - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + @property + def offset_provider_type(self) -> common.OffsetProviderType: + return common.offset_provider_to_type(self.offset_provider) + @classmethod def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" @@ -98,8 +69,7 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: offset_provider = kwargs_copy.pop("offset_provider", {}) return cls( args=compile_args, - offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. - # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + offset_provider=offset_provider, column_axis=kwargs_copy.pop("column_axis", None), kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None @@ -138,18 +108,6 @@ def adapted_jit_to_aot_args_factory() -> ( return toolchain.ArgsOnlyAdapter(jit_to_aot_args) -def connectivity_or_dimension( - some_offset_provider: common.Connectivity | common.Dimension, -) -> CompileTimeConnectivity | common.Dimension: - match some_offset_provider: - case common.Dimension(): - return some_offset_provider - case common.Connectivity(): - return CompileTimeConnectivity.from_connectivity(some_offset_provider) - case _: - raise ValueError - - def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: for element in tuple_arg: match element: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index cc57c137bf..b2aea05641 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -12,7 +12,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts from gt4py.eve.utils import UIDGenerator -from gt4py.next import common from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, gtfn_ir_common from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ( AssignStmt, @@ -84,54 +83,9 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -def _get_connectivity( - applied_reduce_node: gtfn_ir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: - """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - - connectivities: list[common.Connectivity] = [] - for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) - connectivities.append(conn) - - if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") - - if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: - # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") - return connectivities[0] - - # TODO: end of code clone -def _make_dense_acess( - shift_call: gtfn_ir.FunCall, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="deref"), - args=[ - gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="shift"), args=[*shift_call.args, nbh_iter] - ) - ], - ) - - -def _make_sparse_acess( - field_ref: gtfn_ir_common.SymRef, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], - ) - - class PlugInCurrentIdx(NodeTranslator): def visit_SymRef( self, node: gtfn_ir_common.SymRef @@ -225,32 +179,6 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - - connectivity = _get_connectivity(node, offset_provider) - - args = node.args - # do the following transformations to the node arguments - # dense fields: shift(dense_f, X2Y) -> deref(shift(dense_f, X2Y, nbh_iterator) - # sparse_fields: sparse_f -> tuple_get(nbh_iterator, deref(sparse_f))) - new_args = [] - nbh_iter = gtfn_ir_common.SymRef(id="nbh_iter") - for arg in args: - if isinstance(arg, gtfn_ir.FunCall) and arg.fun.id == "shift": # type: ignore - new_args.append(_make_dense_acess(arg, nbh_iter)) - if isinstance(arg, gtfn_ir_common.SymRef): - new_args.append(_make_sparse_acess(arg, nbh_iter)) - - red_idx = self.uids.sequential_id(prefix="red") - if isinstance(node.fun.args[0], gtfn_ir.Lambda): # type: ignore - self._expand_lambda(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - elif isinstance(node.fun.args[0], gtfn_ir_common.SymRef): # type: ignore - self._expand_symref(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - - return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable @@ -278,7 +206,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): - return self.handle_Reduction(node, **kwargs) + raise AssertionError( + "Not implemented. The code-path was removed as it was not actively used and tested." + ) if isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id == "make_tuple": tupl_id = self.uids.sequential_id(prefix="tupl") tuple_fun = self.commit_args(node, tupl_id, "make_tuple", **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ce459f7970..f1649112a7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -82,7 +82,7 @@ def _process_regular_arguments( self, program: itir.FencilDefinition | itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -104,22 +104,22 @@ def _process_regular_arguments( ): # translate sparse dimensions to tuple dtype dim_name = dim.value - connectivity = offset_provider[dim_name] - assert isinstance(connectivity, common.Connectivity) + connectivity = offset_provider_type[dim_name] + assert isinstance(connectivity, common.NeighborConnectivityType) size = connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, common.Connectivity | common.Dimension] + self, offset_provider_type: common.OffsetProviderType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] - for name, connectivity in offset_provider.items(): - if isinstance(connectivity, common.Connectivity): - if connectivity.index_type not in [np.int32, np.int64]: + for name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.NeighborConnectivityType): + if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) @@ -129,15 +129,8 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[ - connectivity.origin_axis, - common.Dimension( - name, kind=common.DimensionKind.LOCAL - ), # TODO(havogt): we should not use the name of the offset as the name of the local dimension - ], - dtype=ts.ScalarType( - type_translation.get_scalar_kind(connectivity.index_type) - ), + dims=list(connectivity_type.domain), + dtype=type_translation.from_dtype(connectivity_type.dtype), ), ) ) @@ -145,19 +138,19 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity.origin_axis.value}_t, " - f"generated::{name}_t, {connectivity.max_neighbors}" + f"generated::{connectivity_type.source_dim.value}_t, " + f"generated::{name}_t, {connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, common.Dimension): + elif isinstance(connectivity_type, common.Dimension): pass else: raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"got '{type(connectivity).__name__}'." + f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"got '{type(connectivity_type).__name__}'." ) return parameters, arg_exprs @@ -165,7 +158,7 @@ def _process_connectivity_args( def _preprocess_program( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -194,7 +187,7 @@ def _preprocess_program( def generate_stencil_source( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: if self.enable_itir_transforms: @@ -204,7 +197,9 @@ def generate_stencil_source( new_program = program gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + new_program, + offset_provider_type=common.offset_provider_to_type(offset_provider), + column_axis=column_axis, ) if self.use_imperative_backend: @@ -224,13 +219,13 @@ def __call__( # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args.args, inp.args.offset_provider + program, inp.args.args, inp.args.offset_provider_type ) # handle connectivity parameters and arguments (i.e. what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider + inp.args.offset_provider_type ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index bc2bd645e8..129d81d6f9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -159,7 +159,7 @@ def _collect_dimensions_from_domain( def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() @@ -167,13 +167,13 @@ def _collect_offset_definitions( .filter(lambda offset_literal: isinstance(offset_literal.value, str)) .getattr("value") ).to_set() - if not used_offset_tags.issubset(set(offset_provider.keys())): + if not used_offset_tags.issubset(set(offset_provider_type.keys())): raise AssertionError("ITIR contains an offset tag without a corresponding offset provider.") offset_definitions = {} - for offset_name, dim_or_connectivity in offset_provider.items(): - if isinstance(dim_or_connectivity, common.Dimension): - dim: common.Dimension = dim_or_connectivity + for offset_name, dim_or_connectivity_type in offset_provider_type.items(): + if isinstance(dim_or_connectivity_type, common.Dimension): + dim: common.Dimension = dim_or_connectivity_type if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -201,12 +201,13 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) - elif isinstance(dim_or_connectivity, common.Connectivity): + elif isinstance( + connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType + ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) - connectivity: common.Connectivity = dim_or_connectivity - for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: + for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -323,7 +324,7 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): } _unary_op_map: ClassVar[dict[str, str]] = {"not_": "!"} - offset_provider: dict + offset_provider_type: common.OffsetProviderType column_axis: Optional[common.Dimension] grid_type: common.GridType @@ -338,18 +339,18 @@ def apply( cls, node: itir.Program, *, - offset_provider: dict, + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) grid_type = _get_gridtype(node.body) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( - offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type + offset_provider_type=offset_provider_type, column_axis=column_axis, grid_type=grid_type ).visit(node) def visit_Sym(self, node: itir.Sym, **kwargs: Any) -> Sym: @@ -484,8 +485,8 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: if "stencil" in kwargs: shift_offsets = self._collect_offset_or_axis_node(itir.OffsetLiteral, kwargs["stencil"]) for o in shift_offsets: - if o in self.offset_provider and isinstance( - self.offset_provider[o], common.Connectivity + if o in self.offset_provider_type and isinstance( + self.offset_provider_type[o], common.NeighborConnectivityType ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( @@ -679,7 +680,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { **_collect_dimensions_from_domain(node.body), - **_collect_offset_definitions(node, self.grid_type, self.offset_provider), + **_collect_offset_definitions(node, self.grid_type, self.offset_provider_type), } return Program( id=SymbolName(node.id), diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index db0df7d121..56ba08015b 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -12,6 +12,7 @@ import dace import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, utils as gtx_utils from . import utility as dace_utils @@ -65,8 +66,8 @@ def _get_args( def _ensure_is_on_device( - connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType -) -> np.typing.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: dace.dtypes.DeviceType +) -> core_defs.NDArrayObject: if device == dace.dtypes.DeviceType.GPU: if not isinstance(connectivity_arg, cp.ndarray): warnings.warn( @@ -78,7 +79,7 @@ def _ensure_is_on_device( def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): @@ -103,7 +104,7 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: stride_args = {} for name, value in args.items(): @@ -134,7 +135,7 @@ def get_sdfg_conn_args( sdfg: dace.SDFG, offset_provider: gtx_common.OffsetProvider, on_gpu: bool, -) -> dict[str, np.typing.NDArray]: +) -> dict[str, core_defs.NDArrayObject]: """ Extracts the connectivity tables that are used in the sdfg and ensures that the memory buffers are allocated for the target device. @@ -142,11 +143,11 @@ def get_sdfg_conn_args( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU connectivity_args = {} - for offset, connectivity in dace_utils.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - param = dace_utils.connectivity_identifier(offset) - if param in sdfg.arrays: - connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + for offset, connectivity in offset_provider.items(): + if gtx_common.is_neighbor_table(connectivity): + param = dace_utils.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) return connectivity_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index bc01e2abda..29395a30c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -79,19 +79,18 @@ def debug_info( return default -def filter_connectivities( - offset_provider: gtx_common.OffsetProvider, -) -> dict[str, gtx_common.Connectivity]: +def filter_connectivity_types( + offset_provider_type: gtx_common.OffsetProviderType, +) -> dict[str, gtx_common.NeighborConnectivityType]: """ - Filter offset providers of type `Connectivity`. + Filter offset provider types of type `NeighborConnectivityType`. In other words, filter out the cartesian offset providers. - Returns a new dictionary containing only `Connectivity` values. """ return { - offset: table - for offset, table in offset_provider.items() - if isinstance(table, gtx_common.Connectivity) + offset: conn + for offset, conn in offset_provider_type.items() + if isinstance(conn, gtx_common.NeighborConnectivityType) } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 73b6e2ed4c..74142dec66 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -527,14 +527,14 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider = self.subgraph_builder.get_offset_provider_type(offset) + assert isinstance(offset_provider, gtx_common.NeighborConnectivityType) it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis in it.dimensions - assert offset_provider.origin_axis in it.indices - origin_index = it.indices[offset_provider.origin_axis] + assert offset_provider.codomain in it.dimensions + assert offset_provider.source_dim in it.indices + origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) @@ -561,7 +561,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis + if dim != offset_provider.codomain else f"0:{size}" for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ) @@ -657,7 +657,9 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: tasklet_expression = f"{output_connector} = {fun_python_code}" input_args = [self.visit(arg) for arg in node.args] - input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + input_connectivity_types: dict[ + gtx_common.Dimension, gtx_common.NeighborConnectivityType + ] = {} for input_arg in input_args: assert isinstance(input_arg.gt_dtype, itir_ts.ListType) assert input_arg.gt_dtype.offset_type is not None @@ -665,11 +667,11 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if offset_type == _CONST_DIM: # this input argument is the result of `make_const_list` continue - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) - input_connectivities[offset_type] = offset_provider + offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType) + input_connectivity_types[offset_type] = offset_provider_t - if len(input_connectivities) == 0: + if len(input_connectivity_types) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors @@ -678,14 +680,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: len( set( (conn.has_skip_values, conn.max_neighbors) - for conn in input_connectivities.values() + for conn in input_connectivity_types.values() ) ) != 1 ): raise ValueError("Unexpected arguments to map expression with different neighborhood.") - offset_type, offset_provider = next(iter(input_connectivities.items())) - local_size = offset_provider.max_neighbors + offset_type, offset_provider_type = next(iter(input_connectivity_types.items())) + local_size = offset_provider_type.max_neighbors map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. @@ -717,14 +719,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) connectivity_slice = self._construct_local_view( MemletExpr( @@ -733,7 +735,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: element_type=node.type.element_type, offset_type=offset_type ), subset=sbs.Range.from_string( - f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) ) @@ -774,7 +776,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: def _make_reduce_with_skip_values( self, input_expr: ValueExpr | MemletExpr, - offset_provider: gtx_common.Connectivity, + offset_provider_type: gtx_common.NeighborConnectivityType, reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, @@ -792,7 +794,7 @@ def _make_reduce_with_skip_values( corresponding neighbor index in the connectivity table is valid, or the identity value if the neighbor index is missing. """ - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( isinstance(input_expr.gt_dtype, itir_ts.ListType) @@ -815,7 +817,7 @@ def _make_reduce_with_skip_values( f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." ) local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider.max_neighbors + assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors # we lower the reduction map with WCR out memlet in a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) @@ -853,7 +855,7 @@ def _make_reduce_with_skip_values( # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. st_reduce.add_mapped_tasklet( name="reduce_with_skip_values", - map_ranges={"i": f"0:{offset_provider.max_neighbors}"}, + map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, inputs={ "__val": dace.Memlet(data="values", subset="i"), "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), @@ -882,7 +884,7 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), nsdfg_node, "neighbor_indices", ) @@ -910,12 +912,17 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: self._make_reduce_with_skip_values( - input_expr, offset_provider, reduce_init, reduce_identity, reduce_wcr, result_node + input_expr, + offset_provider_type, + reduce_init, + reduce_identity, + reduce_wcr, + result_node, ) else: @@ -1082,16 +1089,16 @@ def _make_dynamic_neighbor_offset( def _make_unstructured_shift( self, it: IteratorExpr, - connectivity: gtx_common.Connectivity, + connectivity: gtx_common.NeighborConnectivityType, offset_table_node: dace.nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.neighbor_axis in it.dimensions - neighbor_dim = connectivity.neighbor_axis + assert connectivity.codomain in it.dimensions + neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis + origin_dim = connectivity.source_dim assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1132,7 +1139,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: assert isinstance(offset_provider_arg, gtir.OffsetLiteral) offset = offset_provider_arg.value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr = ( SymbolExpr(offset_value_arg.value, IndexDType) @@ -1140,8 +1147,8 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: else self.visit(offset_value_arg) ) - if isinstance(offset_provider, gtx_common.Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_expr) + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, @@ -1151,7 +1158,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: offset_table_node = self.state.add_access(offset_table) return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_expr + it, offset_provider_type, offset_table_node, offset_expr ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index ad8f490f12..52284edfac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -41,7 +41,7 @@ class DataflowBuilder(Protocol): """Visitor interface to build a dataflow subgraph.""" @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: ... @abc.abstractmethod def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... @@ -155,7 +155,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): from where to continue building the SDFG. """ - offset_provider: gtx_common.OffsetProvider + offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -164,8 +164,8 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - return self.offset_provider[offset] + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: + return self.offset_provider_type[offset] def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -195,10 +195,10 @@ def _make_array_shape_and_strides( Two lists of symbols, one for the shape and the other for the strides of the array. """ dc_dtype = gtir_builtin_translators.INDEX_DTYPE - neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) + neighbor_table_types = dace_utils.filter_connectivity_types(self.offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + neighbor_table_types[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) ) @@ -374,13 +374,12 @@ def _add_sdfg_params( self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables - for offset, offset_provider in dace_utils.filter_connectivities( - self.offset_provider + for offset, connectivity_type in dace_utils.filter_connectivity_types( + self.offset_provider_type ).items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) + scalar_type = tt.from_dtype(connectivity_type.dtype) gt_type = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) # We store all connectivity tables as transient arrays here; later, while building # the field operator expressions, we change to non-transient (i.e. allocated externally) @@ -585,7 +584,7 @@ def visit_Lambda( } # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) + lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -630,7 +629,7 @@ def _flatten_tuples( ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivities(self.offset_provider) + for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) } input_memlets = {} @@ -778,7 +777,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, - offset_provider: gtx_common.OffsetProvider, + offset_provider_type: gtx_common.OffsetProviderType, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -788,15 +787,15 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node + offset_provider_type: The definitions of offset providers used by the program node Returns: An SDFG in the DaCe canonical form (simplified) """ - ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index aa4fd0cd3e..40d44f5ab0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -52,7 +52,9 @@ def generate_sdfg( on_gpu: bool, ) -> dace.SDFG: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) @@ -75,7 +77,7 @@ def __call__( sdfg = self.generate_sdfg( program, - inp.args.offset_provider, + inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, auto_opt=self.auto_optimize, on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fc2772027e..ef09cf51cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -9,7 +9,7 @@ import dataclasses import warnings from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from dataclasses import field from inspect import currentframe, getframeinfo from pathlib import Path @@ -38,7 +38,7 @@ def preprocess_program( program: itir.FencilDefinition, - offset_provider: Mapping[str, Any], + offset_provider_type: common.OffsetProviderType, lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ @@ -51,13 +51,13 @@ def preprocess_program( common_subexpression_elimination=False, force_inline_lambda_args=True, lift_mode=lift_mode, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, symbolic_domain_sizes=symbolic_domain_sizes, temporary_extraction_heuristics=temporary_extraction_heuristics, unroll_reduce=unroll_reduce, ) - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) if isinstance(node, itir.Program): fencil_definition = program_to_fencil.program_to_fencil(node) @@ -72,7 +72,7 @@ def preprocess_program( def build_sdfg_from_itir( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, @@ -109,10 +109,18 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program, tmps = preprocess_program( - program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics + program, + offset_provider_type, + lift_mode, + symbolic_domain_sizes, + temporary_extraction_heuristics, ) sdfg_genenerator = ItirToSDFG( - list(arg_types), offset_provider, tmps, use_field_canonical_representation, column_axis + list(arg_types), + offset_provider_type, + tmps, + use_field_canonical_representation, + column_axis, ) sdfg = sdfg_genenerator.visit(program) if sdfg is None: @@ -186,14 +194,12 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: raise ValueError( "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." ) - offset_provider = ( - self.connectivities | self._implicit_offset_provider - ) # tables are None at this point + offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] self.itir, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, column_axis=kwargs.get("column_axis", None), ) self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ @@ -238,7 +244,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: sdfg.offset_providers_per_input_field = {} itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider=offset_provider + self.itir, offset_provider_type=offset_provider_type ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) for closure in itir_tmp_fencil.closures: @@ -267,7 +273,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ the offset providers are not part of GT4Py Program's arguments. Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. """ - offset_provider = self.connectivities + offset_provider_type = self.connectivities # Define DaCe symbols connectivity_table_size_symbols = { @@ -276,9 +282,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -288,9 +294,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -298,8 +304,8 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Define the storage location (e.g. CPU, GPU) of the connectivity tables if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider.items(): # type: ignore[union-attr] - if not hasattr(v, "table"): + for k, v in offset_provider_type.items(): # type: ignore[union-attr] + if not isinstance(v, common.NeighborConnectivityType): continue if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: Program.connectivity_tables_data_descriptors["storage"] = ( @@ -311,12 +317,15 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Build the closure dictionary closure_dict = {} - for k, v in offset_provider.items(): # type: ignore[union-attr] + for k, v in offset_provider_type.items(): # type: ignore[union-attr] conn_id = dace_utils.connectivity_identifier(k) - if hasattr(v, "table") and conn_id in self.sdfg_closure_vars["sdfg.arrays"]: + if ( + isinstance(v, common.NeighborConnectivityType) + and conn_id in self.sdfg_closure_vars["sdfg.arrays"] + ): if conn_id not in Program.connectivity_tables_data_descriptors: Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.index_type == np.int64 else dace.int32, + dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, shape=[ symbols[dace_utils.field_size_symbol_name(conn_id, 0)], symbols[dace_utils.field_size_symbol_name(conn_id, 1)], diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a0f4b83d35..823943cfd5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -7,14 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Mapping, Optional, Sequence, cast +from typing import Optional, Sequence, cast import dace from dace.sdfg.state import LoopRegion import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind -from gt4py.next.common import Connectivity +from gt4py.next import Dimension, DimensionKind, common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef @@ -91,7 +90,10 @@ def _get_scan_dim( def _make_array_shape_and_strides( - name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool + name: str, + dims: Sequence[Dimension], + offset_provider_type: common.OffsetProviderType, + sort_dims: bool, ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -106,10 +108,10 @@ def _make_array_shape_and_strides( """ dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - neighbor_tables = dace_utils.filter_connectivities(offset_provider) + connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + connectivity_types[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) @@ -144,21 +146,21 @@ class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] column_axis: Optional[Dimension] - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType unique_id: int use_field_canonical_representation: bool def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider_type: common.OffsetProviderType, tmps: list[itir.Temporary], use_field_canonical_representation: bool, column_axis: Optional[Dimension] = None, ): self.param_types = param_types self.column_axis = column_axis - self.offset_provider = offset_provider + self.offset_provider_type = offset_provider_type self.storage_types = {} self.tmps = tmps self.use_field_canonical_representation = use_field_canonical_representation @@ -166,7 +168,7 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): if isinstance(type_, ts.FieldType): shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider, sort_dimensions + name, type_.dims, self.offset_provider_type, sort_dimensions ) dtype = dace_utils.as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -255,7 +257,7 @@ def get_output_nodes( # Visit output node again to generate the corresponding tasklet context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen( - self.offset_provider, context, self.use_field_canonical_representation + self.offset_provider_type, context, self.use_field_canonical_representation ) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -266,7 +268,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): entry_state = program_sdfg.add_state("program_entry", is_start_block=True) # Filter neighbor tables from offset providers. - neighbor_tables = get_used_connectivities(node, self.offset_provider) + connectivity_types = get_used_connectivities(node, self.offset_provider_type) # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): @@ -285,11 +287,10 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state = entry_state # Add connectivities as SDFG storages. - for offset, offset_provider in neighbor_tables.items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + for offset, connectivity_type in connectivity_types.items(): + scalar_type = tt.from_dtype(connectivity_type.dtype) type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) self.add_storage( program_sdfg, @@ -362,7 +363,7 @@ def visit_StencilClosure( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -568,7 +569,7 @@ def _visit_scan_stencil_closure( ) assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -673,7 +674,7 @@ def _visit_scan_stencil_closure( connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, lambda_domain, input_arrays, connectivity_arrays, @@ -738,7 +739,7 @@ def _visit_parallel_stencil_closure( tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... ], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -762,7 +763,7 @@ def _visit_parallel_stencil_closure( context, results = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, index_domain, input_arrays, connectivity_arrays, @@ -788,7 +789,7 @@ def _visit_domain( lower_bound = named_range.args[1] upper_bound = named_range.args[2] translator = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, context, self.use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 991053b4a5..2b2669187a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,8 +19,8 @@ import gt4py.eve.codegen from gt4py import eve -from gt4py.next import Dimension -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity +from gt4py.next import common +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import FunCall, Lambda from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -187,15 +187,15 @@ def _visit_lift_in_neighbors_reduction( transformer: PythonTaskletCodegen, node: itir.FunCall, node_args: Sequence[IteratorExpr | list[ValueExpr]], - offset_provider: Connectivity, + connectivity_type: common.NeighborConnectivityType, map_entry: dace.nodes.MapEntry, map_exit: dace.nodes.MapExit, neighbor_index_node: dace.nodes.AccessNode, neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: assert transformer.context.reduce_identity is not None - neighbor_dim = offset_provider.neighbor_axis.value - origin_dim = offset_provider.origin_axis.value + neighbor_dim = connectivity_type.codomain.value + origin_dim = connectivity_type.source_dim.value lifted_args: list[IteratorExpr | ValueExpr] = [] for arg in node_args: @@ -232,7 +232,7 @@ def _visit_lift_in_neighbors_reduction( assert isinstance(y, ValueExpr) input_nodes[x] = y.value - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider) + neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -294,7 +294,7 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) - if offset_provider.has_skip_values: + if connectivity_type.has_skip_values: # check neighbor validity on if/else inter-state edge # use one branch for connectivity case start_state = lift_context.body.add_state_before( @@ -333,8 +333,8 @@ def builtin_neighbors( assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value assert isinstance(offset_dim, str) - offset_provider = transformer.offset_provider[offset_dim] - if not isinstance(offset_provider, Connectivity): + connectivity_type = transformer.offset_provider_type[offset_dim] + if not isinstance(connectivity_type, common.NeighborConnectivityType): raise NotImplementedError( "Neighbor reduction only implemented for connectivity based on neighbor tables." ) @@ -351,7 +351,7 @@ def builtin_neighbors( iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[offset_provider.origin_axis.value] + origin_index_node = iterator.indices[connectivity_type.source_dim.value] assert transformer.context.reduce_identity is not None assert transformer.context.reduce_identity.dtype == iterator.dtype @@ -361,7 +361,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_value_var, dtype=iterator.dtype, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) @@ -375,7 +375,7 @@ def builtin_neighbors( neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") me, mx = state.add_map( f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, + ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, debuginfo=di, ) @@ -414,7 +414,7 @@ def builtin_neighbors( transformer, lift_node, lift_args, - offset_provider, + connectivity_type, me, mx, neighbor_index_node, @@ -423,13 +423,13 @@ def builtin_neighbors( else: sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" + connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" data_access_tasklet = state.add_tasklet( "data_access", code=f"__data = __field[{data_access_index}] " + ( f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values + if connectivity_type.has_skip_values else "" ), inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, @@ -445,7 +445,7 @@ def builtin_neighbors( ) for dim in iterator.dimensions: connector = f"{dim}_v" - if dim == offset_provider.neighbor_axis.value: + if dim == connectivity_type.codomain.value: state.add_edge( neighbor_index_node, None, @@ -470,7 +470,7 @@ def builtin_neighbors( src_conn="__data", ) - if not offset_provider.has_skip_values: + if not connectivity_type.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] else: """ @@ -483,7 +483,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_valid_var, dtype=dace.dtypes.bool, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) @@ -572,7 +572,7 @@ def build_if_state(arg, state): symbol_map = copy.deepcopy(transformer.context.symbol_map) node_context = Context(sdfg, state, symbol_map) node_taskgen = PythonTaskletCodegen( - transformer.offset_provider, + transformer.offset_provider_type, node_context, transformer.use_field_canonical_representation, ) @@ -884,21 +884,12 @@ def visit_SymRef(self, node: itir.SymRef): ) +@dataclasses.dataclass class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType context: Context use_field_canonical_representation: bool - def __init__( - self, - offset_provider: dict[str, Any], - context: Context, - use_field_canonical_representation: bool, - ): - self.offset_provider = offset_provider - self.context = context - self.use_field_canonical_representation = use_field_canonical_representation - def get_sorted_field_dimensions(self, dims: Sequence[str]): return sorted(dims) if self.use_field_canonical_representation else dims @@ -914,7 +905,7 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {} + get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} ) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -974,7 +965,7 @@ def visit_Lambda( reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, lambda_context, self.use_field_canonical_representation, ) @@ -1066,7 +1057,7 @@ def _visit_call(self, node: itir.FunCall): store, self.context.body.arrays[store] ) - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider) + neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) for offset in neighbor_tables.keys(): var = dace_utils.connectivity_identifier(offset) nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) @@ -1136,12 +1127,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] assert len(dims_not_indexed) == 1 offset = dims_not_indexed[0] - offset_provider = self.offset_provider[offset] - neighbor_dim = offset_provider.neighbor_axis.value + offset_provider_type = self.offset_provider_type[offset] + assert isinstance(offset_provider_type, common.NeighborConnectivityType) + neighbor_dim = offset_provider_type.codomain.value result_name = unique_var_name() self.context.body.add_array( - result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -1158,7 +1150,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a mapped tasklet for array slicing index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"} + map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} src_subset = ",".join( [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] ) @@ -1212,27 +1204,30 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: offset_node = self.visit(tail[1])[0] assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - if isinstance(self.offset_provider[offset_dim], Connectivity): - offset_provider = self.offset_provider[offset_dim] + if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): + offset_provider_type = cast( + common.NeighborConnectivityType, self.offset_provider_type[offset_dim] + ) # ensured by condition connectivity = self.context.state.add_access( dace_utils.connectivity_identifier(offset_dim), debuginfo=di ) - shifted_dim = offset_provider.origin_axis.value - target_dim = offset_provider.neighbor_axis.value + shifted_dim_tag = offset_provider_type.source_dim.value + target_dim_tag = offset_provider_type.codomain.value args = [ ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), + ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" else: - assert isinstance(self.offset_provider[offset_dim], Dimension) + shifted_dim = self.offset_provider_type[offset_dim] + assert isinstance(shifted_dim, common.Dimension) - shifted_dim = self.offset_provider[offset_dim].value - target_dim = shifted_dim - args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] + shifted_dim_tag = shifted_dim.value + target_dim_tag = shifted_dim_tag + args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" @@ -1241,8 +1236,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim] - shifted_index[target_dim] = shifted_value + del shifted_index[shifted_dim_tag] + shifted_index[target_dim_tag] = shifted_value return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) @@ -1506,7 +1501,7 @@ def is_scan(node: itir.Node) -> bool: def closure_to_tasklet_sdfg( node: itir.StencilClosure, - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, domain: dict[str, str], inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], @@ -1547,7 +1542,9 @@ def closure_to_tasklet_sdfg( body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation) + translator = PythonTaskletCodegen( + offset_provider_type, context, use_field_canonical_representation + ) args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d367eb0883..72bb32f003 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -7,21 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import itertools -from typing import Any, Mapping +from typing import Any import dace import gt4py.next.iterator.ir as itir from gt4py import eve -from gt4py.next.common import Connectivity +from gt4py.next import common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils def get_used_connectivities( - node: itir.Node, offset_provider: Mapping[str, Any] -) -> dict[str, Connectivity]: - connectivities = dace_utils.filter_connectivities(offset_provider) + node: itir.Node, offset_provider_type: common.OffsetProviderType +) -> dict[str, common.NeighborConnectivityType]: + connectivities = dace_utils.filter_connectivity_types(offset_provider_type) offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 740f1979cd..653ed4719d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -52,7 +52,7 @@ def generate_sdfg( self, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> dace.SDFG: on_gpu = ( @@ -64,7 +64,7 @@ def generate_sdfg( return build_sdfg_from_itir( program, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, @@ -87,7 +87,7 @@ def __call__( sdfg = self.generate_sdfg( program, inp.args.args, - inp.args.offset_provider, + common.offset_provider_to_type(inp.args.offset_provider), inp.args.column_axis, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 965c6417b2..1f3778f227 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,14 +12,12 @@ import diskcache import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.common import Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind @@ -63,8 +61,8 @@ def decorated_program( def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -79,17 +77,17 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): + if not common.is_neighbor_table(conn): raise NotImplementedError( "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) + conn_arg = _ensure_is_on_device(conn.ndarray, device) args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass @@ -125,7 +123,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: the program, sorted offset_provider, and column_axis. """ program: itir.FencilDefinition | itir.Program = inp.data - offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis program_hash = utils.content_hash( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 4d518d7fcc..1dd568b95a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -94,7 +94,7 @@ def fencil_generator( ir: itir.Program | itir.FencilDefinition, debug: bool, use_embedded: bool, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ @@ -111,7 +111,15 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash( + ( + ir, + transforms, + debug, + use_embedded, + tuple(common.offset_provider_to_type(offset_provider).items()), + ) + ) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") @@ -151,7 +159,9 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as source_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: source_file_name = source_file.name if debug: print(source_file_name) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..fa8c9b9ab1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -63,6 +63,7 @@ class DimensionType(TypeSpec): @dataclass(frozen=True) class OffsetType(TypeSpec): + # TODO(havogt): replace by ConnectivityType source: func_common.Dimension target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 1da34db3c0..f5646c71e4 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,30 +6,32 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -from typing import Optional from types import ModuleType +from typing import Optional + +import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend -from gt4py.next.otf import arguments +from gt4py.next import backend as next_backend, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + E2V, + E2VDim, + Edge, + Vertex, exec_alloc_descriptor, mesh_descriptor, - Vertex, - Edge, - E2V, ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, - laplap_program, lap_ref, + laplap_program, ) + try: import dace from gt4py.next.program_processors.runners.dace import ( @@ -57,25 +59,20 @@ def test_sdfgConvertible_laplap(cartesian_case): in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() - connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None - for k, v in cartesian_case.offset_provider.items(): - if hasattr(v, "table"): - connectivities[k] = arguments.CompileTimeConnectivity( - v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype - ) - else: - connectivities[k] = v - # Test DaCe closure support @dace.program def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(in_field, tmp_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + in_field, tmp_field + ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(tmp_field, out_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + tmp_field, out_field + ) sdfg() @@ -130,13 +127,13 @@ def sdfg( a, out, offset_provider=offset_provider ) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False - ) - connectivities = {} - connectivities["E2V"] = arguments.CompileTimeConnectivity( - e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, ) + connectivities = {"E2V": e2v.__gt_type__()} offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -144,6 +141,9 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) + e2v_ndarray_copy = ( + e2v.ndarray.copy() + ) # otherwise DaCe complains about the gt4py custom allocated view # This is a low level interface to call the compiled SDFG. # It is not supposed to be used in user code. # The high level interface should be provided by a DaCe Orchestrator, @@ -155,21 +155,21 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[1, 0], [2, 1], [0, 2]]), Edge, Vertex, 2, False + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[1, 0], [2, 1], [0, 2]]), + allocator=allocator, ) + e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) cSDFG( a, @@ -177,17 +177,13 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) def get_stride_from_numpy_to_dace(numpy_array: np.ndarray, axis: int) -> int: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index c64efb27d2..794dd06709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -152,7 +152,10 @@ def num_edges(self) -> int: ... def num_levels(self) -> int: ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> common.OffsetProvider: ... + + @property + def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh() -> MeshDescriptor: @@ -211,25 +214,40 @@ def simple_mesh() -> MeshDescriptor: assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 4}, + skip_value=None, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 4}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 4}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="simple_mesh", num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 4, has_skip_values=False - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 4, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 4, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -287,25 +305,40 @@ def skip_value_mesh() -> MeshDescriptor: dtype=gtx.IndexType, ) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 5}, + skip_value=common._DEFAULT_SKIP_VALUE, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 3}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 3}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 5, has_skip_values=True - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 3, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 3, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a5453151e6..1a51e3667d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -89,7 +89,7 @@ def testee(a: cases.VField) -> cases.EField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: a[unstructured_case.offset_provider["E2V"].table[:, 0]], + ref=lambda a: a[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -115,16 +115,16 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_flat, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_intermediate_result, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], comparison=lambda inp, tmp: np.all(inp == tmp), ) @@ -132,8 +132,8 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) @@ -583,11 +583,11 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ - unstructured_case.offset_provider["V2E"].table + np.sum(a[unstructured_case.offset_provider["E2V"].ndarray], axis=1, initial=0)[ + unstructured_case.offset_provider["V2E"].ndarray ], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].ndarray != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -606,8 +606,8 @@ def testee(inp: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda inp: np.sum( - np.sum(inp[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table + np.sum(inp[unstructured_case.offset_provider["V2E"].ndarray], axis=1)[ + unstructured_case.offset_provider["E2V"].ndarray ], axis=1, ), @@ -627,8 +627,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField unstructured_case, testee, ref=lambda a, b: [ - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1), + np.sum(a[unstructured_case.offset_provider["V2E"].ndarray], axis=1), + np.sum(b[unstructured_case.offset_provider["V2E"].ndarray], axis=1), ], comparison=lambda a, tmp: (np.all(a[0] == tmp[0]), np.all(a[1] == tmp[1])), ) @@ -649,11 +649,11 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + e[v2e.ndarray] + np.tile(v, (v2e.shape[1], 1)).T, axis=1, initial=0, - where=v2e.table != common._DEFAULT_SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table[:, 0]], + where=v2e.ndarray != common._DEFAULT_SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -780,7 +780,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 37f4ee2cd1..33832fb5f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -33,11 +33,11 @@ def testee( ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify( unstructured_case, testee, @@ -57,7 +57,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 return neighbor_sum(inp, axis=V2EDim) inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) cases.verify( @@ -65,7 +65,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 testee, inp, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum(unstructured_case.offset_provider["V2E"].ndarray, axis=1), ) @@ -76,7 +76,7 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: return inp(V2E) out = unstructured_case.as_field( - [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].ndarray) ) inp = cases.allocate(unstructured_case, testee, "inp")() cases.verify( @@ -84,5 +84,5 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: testee, inp, out=out, - ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].ndarray], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 29966c30ad..7648d34db7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -52,7 +52,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp = cases.allocate(unstructured_case, testee, "edge_f", strategy=strategy)() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray ref = np.max( inp.asnumpy()[v2e_table], axis=1, @@ -69,7 +69,7 @@ def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) return out - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, minover, @@ -106,7 +106,7 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray edge_f = cases.allocate(unstructured_case, fop, "edge_f")() @@ -157,7 +157,7 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() @@ -190,7 +190,7 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, fencil, @@ -210,7 +210,7 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -226,7 +226,7 @@ def test_reduction_expression_with_where(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -255,7 +255,7 @@ def test_reduction_expression_with_where_and_tuples(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -284,7 +284,7 @@ def test_reduction_expression_with_where_and_scalar(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 11e28de9e1..66c56c4827 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -90,7 +90,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].ndarray[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 3fc4ed9945..5e3a2fcd14 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -248,11 +248,14 @@ def test_can_deref(program_processor, stencil): program_processor, validate = program_processor Node = gtx.Dimension("Node") + NeighDim = gtx.Dimension("Neighbor", kind=gtx.DimensionKind.LOCAL) inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) out = gtx.as_field([Node], np.asarray([0], dtype=inp.dtype)) - no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) + no_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[-1]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -264,7 +267,9 @@ def test_can_deref(program_processor, stencil): if validate: assert np.allclose(out.asnumpy(), -1.0) - a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) + a_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[0]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -277,37 +282,6 @@ def test_can_deref(program_processor, stencil): assert np.allclose(out.asnumpy(), 1.0) -# def test_can_deref_lifted(program_processor): -# program_processor, validate = program_processor - -# Neighbor = offset("Neighbor") -# Node = gtx.Dimension("Node") - -# @fundef -# def _can_deref(inp): -# shifted = shift(Neighbor, 0)(inp) -# return if_(can_deref(shifted), 1, -1) - -# inp = gtx.as_field([Node], np.zeros((1,))) -# out = gtx.as_field([Node], np.asarray([0])) - -# no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": no_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), -1.0) - -# a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": a_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), 1.0) - - @pytest.mark.parametrize( "input_value, dtype, np_dtype", [ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 69786b323b..7bde55bfd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor +from gt4py.next.iterator.embedded import StridedConnectivityField LocA = gtx.Dimension("LocA") @@ -21,8 +22,10 @@ LocB = gtx.Dimension("LocB") # unused LocA2LocAB = offset("O") -LocA2LocAB_offset_provider = gtx.StridedNeighborOffsetProvider( - origin_axis=LocA, neighbor_axis=LocAB, max_neighbors=2, has_skip_values=False +LocA2LocAB_offset_provider = StridedConnectivityField( + domain_dims=(LocA, gtx.Dimension("Dummy", kind=gtx.DimensionKind.LOCAL)), + codomain_dim=LocAB, + max_neighbors=2, ) @@ -41,7 +44,7 @@ def test_strided_offset_provider(program_processor): program_processor, validate = program_processor LocA_size = 2 - max_neighbors = LocA2LocAB_offset_provider.max_neighbors + max_neighbors = LocA2LocAB_offset_provider.__gt_type__().max_neighbors LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index eb59c77201..6c6ca7e4bc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,7 +11,6 @@ import numpy as np import pytest - pytest.importorskip("atlas4py") from gt4py import next as gtx @@ -22,20 +21,17 @@ exec_alloc_descriptor, ) from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + E2V, + V2E, + E2VDim, + Edge, + V2EDim, + Vertex, assert_close, nabla_setup, ) -Vertex = gtx.Dimension("Vertex") -Edge = gtx.Dimension("Edge") -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) - - @gtx.field_operator def compute_zavgS( pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] @@ -67,21 +63,19 @@ def pnabla( def test_ffront_compute_zavgS(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + setup = nabla_setup(allocator=allocator) zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - - compute_zavgS.with_backend(exec_alloc_descriptor)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + compute_zavgS.with_backend( + None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor + )( + setup.input_field, + setup.S_fields[0], + out=zavgS, + offset_provider={"E2V": setup.edges2node_connectivity}, ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) @@ -89,27 +83,23 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): def test_ffront_nabla(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) + setup = nabla_setup(allocator=allocator) pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - v2e = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 - ) - - pnabla.with_backend(exec_alloc_descriptor)( - pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} + pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + out=(pnabla_MXX, pnabla_MYY), + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) # TODO this check is not sensitive enough, need to implement a proper numpy reference! diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 8d7324f438..6a5865134d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -20,6 +20,18 @@ functionspace, ) +from gt4py import next as gtx +from gt4py.next.iterator import atlas_utils + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + def assert_close(expected, actual): assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) @@ -33,9 +45,10 @@ def _default_config(): config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=None): + def __init__(self, *, allocator, grid=StructuredGrid("O32"), config=None): if config is None: config = self._default_config() + self.allocator = allocator mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) @@ -55,12 +68,22 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): self.edges_per_node = edges_per_node @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity + def edges2node_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Edge: self.edges_size, E2VDim: 2}, + codomain=Vertex, + data=atlas_utils.AtlasTable(self.mesh.edges.node_connectivity).asnumpy(), + allocator=self.allocator, + ) @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity + def nodes2edge_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Vertex: self.nodes_size, V2EDim: self.edges_per_node}, + codomain=Edge, + data=atlas_utils.AtlasTable(self.mesh.nodes.edge_connectivity).asnumpy(), + allocator=self.allocator, + ) @property def nodes_size(self): @@ -75,16 +98,16 @@ def _is_pole_edge(e, edge_flags): return Topology.check(edge_flags[e], Topology.POLE) @property - def is_pole_edge_field(self): + def is_pole_edge_field(self) -> gtx.Field: edge_flags = np.array(self.mesh.edges.flags()) pole_edge_field = np.zeros((self.edges_size,), dtype=bool) for e in range(self.edges_size): pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field + return gtx.as_field([Edge], pole_edge_field, allocator=self.allocator) @property - def sign_field(self): + def sign_field(self) -> gtx.Field: node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = np.array(self.mesh.edges.flags()) @@ -100,10 +123,10 @@ def sign_field(self): node2edge_sign[jnode, jedge] = -1.0 if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign + return gtx.as_field([Vertex, V2EDim], node2edge_sign, allocator=self.allocator) @property - def S_fields(self): + def S_fields(self) -> tuple[gtx.Field, gtx.Field]: S = np.array(self.mesh.edges.field("dual_normals"), copy=False) S_MXX = np.zeros((self.edges_size)) S_MYY = np.zeros((self.edges_size)) @@ -124,10 +147,12 @@ def S_fields(self): assert math.isclose(min(S_MYY), -2001577.7946404363) assert math.isclose(max(S_MYY), 2001577.7946404363) - return S_MXX, S_MYY + return gtx.as_field([Edge], S_MXX, allocator=self.allocator), gtx.as_field( + [Edge], S_MYY, allocator=self.allocator + ) @property - def vol_field(self): + def vol_field(self) -> gtx.Field: rpi = 2.0 * math.asin(1.0) radius = 6371.22e03 deg2rad = 2.0 * rpi / 360.0 @@ -142,10 +167,10 @@ def vol_field(self): # VOL(min/max): 57510668192.214096 851856184496.32886 assert_close(57510668192.214096, min(vol)) assert_close(851856184496.32886, max(vol)) - return vol + return gtx.as_field([Vertex], vol, allocator=self.allocator) @property - def input_field(self): + def input_field(self) -> gtx.Field: klevel = 0 MXX = 0 MYY = 1 @@ -200,4 +225,5 @@ def input_field(self): assert_close(0.0000000000000000, min(rzs)) assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] + + return gtx.as_field([Vertex], rzs[:, klevel], allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 3db4497910..4487681abf 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -111,25 +111,18 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas def test_compute_zavgS(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + setup = nabla_setup(allocator=None) zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS_fencil, program_processor, setup.edges_size, zavgS, - pp, - S_MXX, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[0], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -141,9 +134,9 @@ def test_compute_zavgS(program_processor): program_processor, setup.edges_size, zavgS, - pp, - S_MYY, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[1], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -158,29 +151,21 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): @pytest.mark.requires_atlas def test_compute_zavgS2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - - S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + setup = nabla_setup(allocator=None) zavgS = ( gtx.as_field([Edge], np.zeros((setup.edges_size))), gtx.as_field([Edge], np.zeros((setup.edges_size))), ) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS2_fencil, program_processor, setup.edges_size, zavgS, - pp, - S, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields, + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -195,34 +180,27 @@ def test_compute_zavgS2(program_processor): def test_nabla(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, + setup.input_field, S_MXX, S_MYY, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -245,33 +223,24 @@ def nabla2(n_nodes, out, pp, S, sign, vol): @pytest.mark.requires_atlas def test_nabla2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) - vol = gtx.as_field([Vertex], setup.vol_field) + setup = nabla_setup(allocator=None) pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla2, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, - S_M, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -325,36 +294,29 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ def test_nabla_sign(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla_sign, program_processor, setup.nodes_size, pnabla_MXX, pnabla_MYY, - pp, + setup.input_field, S_MXX, S_MYY, - vol, + setup.vol_field, gtx.index_field(Vertex), - is_pole_edge, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.is_pole_edge_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fdc6a77a1..ac7ce9e544 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -38,9 +38,13 @@ V2VDim, Vertex, c2e_arr, + c2e_conn, e2v_arr, + e2v_conn, v2e_arr, + v2e_conn, v2v_arr, + v2v_conn, ) from next_tests.unit_tests.conftest import program_processor, run_processor @@ -89,7 +93,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -111,7 +115,7 @@ def test_map_neighbors(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -134,7 +138,7 @@ def test_map_make_const_list(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -157,8 +161,8 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo inp, out=out, offset_provider={ - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + "E2V": e2v_conn, + "C2E": c2e_conn, }, ) if validate: @@ -185,7 +189,7 @@ def test_sparse_input_field(program_processor): non_sparse, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: @@ -208,8 +212,8 @@ def test_sparse_input_field_v2v(program_processor): inp, out=out, offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2V": v2v_conn, + "V2E": v2e_conn, }, ) @@ -235,7 +239,7 @@ def test_slice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -259,7 +263,7 @@ def test_slice_twice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -284,7 +288,7 @@ def test_shift_sliced_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -309,7 +313,7 @@ def test_slice_shifted_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -337,7 +341,7 @@ def test_lift(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -360,7 +364,7 @@ def test_shift_sparse_input_field(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -393,8 +397,8 @@ def test_shift_sparse_input_field2(program_processor): out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "E2V": e2v_conn, + "V2E": v2e_conn, } domain = {Vertex: range(0, 9)} @@ -448,7 +452,7 @@ def test_sparse_shifted_stencil_reduce(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 82c91a5e74..50db24b880 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -49,6 +49,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -64,6 +66,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) + e2v_arr = np.array( [ [0, 1], @@ -88,6 +92,7 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) # order east, north, west, south (counter-clock wise) v2e_arr = np.array( @@ -104,3 +109,5 @@ ], dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) + +v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index ca66b45d6d..f1269f1ed8 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,11 +14,11 @@ import pytest import gt4py.next as gtx -from gt4py.next import backend +from gt4py.next import backend, common +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import runtime from gt4py.next.program_processors import program_formatter - import next_tests @@ -97,12 +97,21 @@ def run_processor( @dataclasses.dataclass -class DummyConnectivity: +class DummyConnectivity(common.Connectivity): max_neighbors: int has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int + source_dim: gtx.Dimension = gtx.Dimension("dummy_origin") + codomain: gtx.Dimension = gtx.Dimension("dummy_neighbor") + + +def nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + - def mapped_index(_, __) -> int: - return 0 +@pytest.fixture(params=nd_array_implementation_params()) +def nd_array_implementation(request): + yield request.param diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 063e79d92e..9dde5bb40a 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex +from gt4py.next.common import Dimension, Domain, NamedIndex, NamedRange, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -28,19 +28,6 @@ D2 = Dimension("D2") -def nd_array_implementation_params(): - for xp in nd_array_field._nd_array_implementations: - if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: - yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) - else: - yield pytest.param(xp, id=xp.__name__) - - -@pytest.fixture(params=nd_array_implementation_params()) -def nd_array_implementation(request): - yield request.param - - @pytest.fixture( params=[ operator.add, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index dcc3a306f2..a91dbeb608 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -31,12 +31,10 @@ # 0 --0-- 1 --1-- 2 e2v_arr = np.array([[0, 1], [1, 2]]) -e2v_conn = gtx.NeighborTableOffsetProvider( - table=e2v_arr, - origin_axis=E, - neighbor_axis=V, - max_neighbors=2, - has_skip_values=False, +e2v_conn = gtx.as_connectivity( + domain={E: 2, E2VDim: 2}, + codomain=V, + data=e2v_arr, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 1f08362f4f..13e8637d1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -10,18 +10,22 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.iterator.builtins import deref from gt4py.next.iterator.runtime import CartesianDomain, UnstructuredDomain, _deduce_domain, fundef -from next_tests.unit_tests.conftest import DummyConnectivity - @fundef def foo(inp): return deref(inp) -connectivity = DummyConnectivity(max_neighbors=0, has_skip_values=True) +connectivity = common.ConnectivityType( + domain=[gtx.Dimension("dummy_origin"), gtx.Dimension("dummy_neighbor")], + codomain=gtx.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE, + dtype=None, +) def test_deduce_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b6214fb1b..65a5b5888d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -218,11 +218,11 @@ def expression_test_cases(): @pytest.mark.parametrize("test_case", expression_test_cases()) def test_expression_type(test_case): mesh = simple_mesh() - offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} + offset_provider_type = {**mesh.offset_provider_type, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case result = itir_type_inference.infer( - testee, offset_provider=offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=offset_provider_type, allow_undeclared_symbols=True ) assert result.type == expected_type @@ -231,14 +231,16 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.type == ts.TupleType(types=[bool_type, int_type]) def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider_type={}) assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type @@ -253,7 +255,7 @@ def test_late_offset_axis(): testee = im.call(func)(im.ensure_offset("V2E")) result = itir_type_inference.infer( - testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=mesh.offset_provider_type, allow_undeclared_symbols=True ) assert result.type == it_on_e_of_e_type @@ -265,7 +267,9 @@ def test_cast_first_arg_inference(): testee = im.call("cast_")( im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.args[0].type == int_type assert result.type == float64_type @@ -291,7 +295,7 @@ def test_cartesian_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -336,7 +340,7 @@ def test_unstructured_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[Vertex, KDim]), @@ -384,7 +388,7 @@ def test_function_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -429,7 +433,7 @@ def test_fencil_with_nb_field_input(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) assert result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type @@ -456,7 +460,7 @@ def test_program_tuple_setat_short_target(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.TupleType) @@ -487,7 +491,7 @@ def test_program_setat_without_domain(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.DeferredType) @@ -512,7 +516,9 @@ def test_if_stmt(): false_branch=[], ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field @@ -522,7 +528,7 @@ def test_as_fieldop_without_domain(): im.ref("inp", float_i_field) ) result = itir_type_inference.infer( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert result.type == ts.DeferredType(constraint=ts.FieldType) assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index e04856b75f..f4ea2d7fe1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -21,7 +21,7 @@ @pytest.fixture -def offset_provider(request): +def offset_provider_type(request): return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} @@ -137,7 +137,7 @@ def common_expr(): assert actual == expected -def test_if_can_deref_no_extraction(offset_provider): +def test_if_can_deref_no_extraction(offset_provider_type): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -157,11 +157,11 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_can_deref_eligible_extraction(offset_provider): +def test_if_can_deref_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -178,11 +178,11 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_eligible_extraction(offset_provider): +def test_if_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. @@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 141091b450..817c06e8f0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -14,11 +14,12 @@ from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next import constructors from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension -from gt4py.next import common, NeighborTableOffsetProvider +from gt4py.next import common from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next import utils @@ -29,6 +30,7 @@ KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) Edge = common.Dimension(value="Edge", kind=common.DimensionKind.HORIZONTAL) +E2VDim = common.Dimension(value="E2V", kind=common.DimensionKind.LOCAL) @pytest.fixture @@ -39,11 +41,10 @@ def offset_provider(): @pytest.fixture def unstructured_offset_provider(): return { - "E2V": NeighborTableOffsetProvider( - np.array([[0, 1]], dtype=np.int32), - Edge, - Vertex, - 2, + "E2V": constructors.as_connectivity( + domain={Edge: 1, E2VDim: 2}, + codomain=Vertex, + data=np.array([[0, 1]], dtype=np.int32), ) } diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index b5b9a62009..168e9490e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -13,6 +13,7 @@ from gt4py.next.iterator.transforms import fuse_as_fieldop from gt4py.next.type_system import type_specifications as ts + IDim = gtx.Dimension("IDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) @@ -30,7 +31,7 @@ def test_trivial(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -40,7 +41,7 @@ def test_trivial_literal(): testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -65,7 +66,7 @@ def test_tuple_arg(): d, )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -85,7 +86,7 @@ def test_symref_used_twice(): d, )("inp1", "inp2") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -100,7 +101,7 @@ def test_no_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee @@ -132,6 +133,6 @@ def test_partial_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 23f62842c4..9d51dc4f33 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -52,7 +52,7 @@ def test_trivial(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -87,7 +87,7 @@ def test_trivial_let(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -128,7 +128,7 @@ def test_top_level_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -186,7 +186,7 @@ def test_nested_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 7c991fb9a8..77d3323fb4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -8,16 +8,16 @@ from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) @@ -32,7 +32,7 @@ def test_prune_casts_fieldop(): im.cast_as_fieldop("float64")(x_ref), im.cast_as_fieldop("float64")(y_ref), ) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.op_as_fieldop("plus")( im.cast_as_fieldop("float64")(x_ref), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 28bd88b853..0760247996 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -11,11 +11,20 @@ import pytest from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags -from next_tests.unit_tests.conftest import DummyConnectivity + +def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): + return common.NeighborConnectivityType( + domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], + codomain=common.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + dtype=None, + max_neighbors=max_neighbors, + ) @pytest.fixture(params=[True, False]) @@ -67,7 +76,7 @@ def reduction_if(): ], ) def test_get_partial_offsets(reduction, request): - offset_provider = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} + offset_provider_type = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} partial_offsets = _get_partial_offset_tags(request.getfixturevalue(reduction).args) assert set(partial_offsets) == {"Dim"} @@ -108,63 +117,73 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): def test_basic(basic_reduction, has_skip_values): expected = _expected(basic_reduction, "Dim", 3, has_skip_values) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=3, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(basic_reduction, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply(basic_reduction, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, has_skip_values): expected = _expected(reduction_with_shift_on_second_arg, "Dim", 1, has_skip_values, 1) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=1, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(reduction_with_shift_on_second_arg, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=1, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply( + reduction_with_shift_on_second_arg, offset_provider_type=offset_provider_type + ) assert actual == expected def test_reduction_with_if(reduction_if): expected = _expected(reduction_if, "Dim", 2, False) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} - actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + offset_provider_type = {"Dim": dummy_connectivity_type(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "IrrelevantDim": DummyConnectivity( + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "IrrelevantDim": dummy_connectivity_type( max_neighbors=1, has_skip_values=True ), # different max_neighbors and skip value to trigger error } actual = UnrollReduce.apply( - reduction_with_irrelevant_full_shift, offset_provider=offset_provider + reduction_with_irrelevant_full_shift, offset_provider_type=offset_provider_type ) assert actual == expected @pytest.mark.parametrize( - "offset_provider", + "offset_provider_type", [ { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=3, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=3, has_skip_values=True), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=True), }, ], ) -def test_reduction_with_incompatible_shifts(reduction_with_incompatible_shifts, offset_provider): - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), +def test_reduction_with_incompatible_shifts( + reduction_with_incompatible_shifts, offset_provider_type +): + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), } with pytest.raises(RuntimeError, match="incompatible"): - UnrollReduce.apply(reduction_with_incompatible_shifts, offset_provider=offset_provider) + UnrollReduce.apply( + reduction_with_incompatible_shifts, offset_provider_type=offset_provider_type + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 1a86f7b0f8..97591122e5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -21,7 +21,7 @@ def test_funcall_to_op(): ) actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -32,7 +32,7 @@ def test_unapplied_funcall_to_function_object(): expected = gtfn_ir.SymRef(id="plus") actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 329b2814d2..62d88d9f0a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -11,6 +11,7 @@ import ctypes import unittest import unittest.mock +from unittest.mock import patch import numpy as np import pytest @@ -20,19 +21,15 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - E2V, - cartesian_case, - unstructured_case, -) +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, mesh_descriptor, ) -from unittest.mock import patch from . import pytestmark + dace = pytest.importorskip("dace") @@ -151,14 +148,14 @@ def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): # check that test connectivities are allocated on host memory # this is an assumption to test that fast_call cannot be used for gpu tests - assert isinstance(connectivity_E2V.table, np.ndarray) + assert isinstance(connectivity_E2V.ndarray, np.ndarray) @gtx.field_operator def testee(a: cases.VField) -> cases.EField: return a(E2V[0]) (a,), kwfields = cases.get_default_data(unstructured_case, testee) - numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + numpy_ref = lambda a: a[connectivity_E2V.ndarray[:, 0]] mock_fast_call, mock_construct_args = make_mocks(monkeypatch) @@ -194,12 +191,11 @@ def verify_testee(offset_provider): # Here we copy the connectivity to gpu memory, and resuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - table=cp.asarray(connectivity_E2V.table), - origin_axis=connectivity_E2V.origin_axis, - neighbor_axis=connectivity_E2V.neighbor_axis, - max_neighbors=connectivity_E2V.max_neighbors, - has_skip_values=connectivity_E2V.has_skip_values, + "E2V": gtx.as_connectivity( + domain=connectivity_E2V.domain, + codomain=connectivity_E2V.codomain, + data=cp.asarray(connectivity_E2V.ndarray), + skip_value=connectivity_E2V.skip_value, ) } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e0c0c3fa4e..9c52ea81c3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, constructors from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -50,13 +50,7 @@ "IDim": IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() -SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS -) SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() -SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS -) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -83,20 +77,20 @@ def make_mesh_symbols(mesh: MeshDescriptor): __vertices_size_0=mesh.num_vertices, __vertices_stride_0=1, __connectivity_C2E_size_0=mesh.num_cells, - __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) @@ -1018,14 +1012,14 @@ def test_gtir_connectivity_shift(): CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) ev = np.random.rand(SIMPLE_MESH.num_edges, SIMPLE_MESH.num_vertices) - ref = ev[connectivity_C2E.table[:, C2E_neighbor_idx], :][ - :, connectivity_E2V.table[:, E2V_neighbor_idx] + ref = ev[connectivity_C2E.ndarray[:, C2E_neighbor_idx], :][ + :, connectivity_E2V.ndarray[:, E2V_neighbor_idx] ] for i, stencil in enumerate( @@ -1053,7 +1047,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1062,8 +1056,8 @@ def test_gtir_connectivity_shift(): ev, c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), e2v_offset=np.full(SIMPLE_MESH.num_edges, E2V_neighbor_idx, dtype=np.int32), - connectivity_C2E=connectivity_C2E.table, - connectivity_E2V=connectivity_E2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __ce_field_size_0=SIMPLE_MESH.num_cells, @@ -1114,15 +1108,17 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + ref = e[ + connectivity_V2E.ndarray[connectivity_E2V.ndarray[:, E2V_neighbor_idx], V2E_neighbor_idx] + ] # new empty output field e_out = np.empty_like(e) @@ -1130,8 +1126,8 @@ def test_gtir_connectivity_shift_chain(): sdfg( e, e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, + connectivity_E2V=connectivity_E2V.ndarray, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __edges_out_size_0=SIMPLE_MESH.num_edges, @@ -1174,30 +1170,30 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.shape[1]) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1210,7 +1206,7 @@ def test_gtir_neighbors_as_output(): gtx_common.GridType.UNSTRUCTURED, ranges={ Vertex: (0, "nvertices"), - V2EDim: (0, SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors), + V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) @@ -1232,9 +1228,9 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) @@ -1243,7 +1239,7 @@ def test_gtir_neighbors_as_output(): sdfg( e, v2e_field, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, @@ -1251,7 +1247,7 @@ def test_gtir_neighbors_as_output(): __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, ) - assert np.allclose(v2e_field, e[connectivity_V2E.table]) + assert np.allclose(v2e_field, e[connectivity_V2E.ndarray]) def test_gtir_reduce(): @@ -1278,13 +1274,13 @@ def test_gtir_reduce(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1305,7 +1301,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1313,7 +1309,7 @@ def test_gtir_reduce(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1344,7 +1340,7 @@ def test_gtir_reduce_with_skip_values(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SKIP_VALUE_MESH.num_edges) @@ -1354,7 +1350,7 @@ def test_gtir_reduce_with_skip_values(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1375,7 +1371,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1383,7 +1379,7 @@ def test_gtir_reduce_with_skip_values(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), ) @@ -1394,10 +1390,10 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1409,7 +1405,7 @@ def test_gtir_reduce_dot_product(): ), init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field) ] testee = gtir.Program( @@ -1448,17 +1444,17 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1500,14 +1496,14 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1525,19 +1521,19 @@ def test_gtir_reduce_with_cond_neighbors(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( np.bool_(use_sparse), v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1631,9 +1627,9 @@ def test_gtir_let_lambda_with_connectivity(): C2V_neighbor_idx = 2 cell_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Cell: (0, "ncells")}) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, gtx_common.NeighborTable) testee = gtir.Program( @@ -1669,22 +1665,22 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) c = np.empty(SIMPLE_MESH.num_cells) ref = ( - e[connectivity_C2E.table[:, C2E_neighbor_idx]] - + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + e[connectivity_C2E.ndarray[:, C2E_neighbor_idx]] + + v[connectivity_C2V.ndarray[:, C2V_neighbor_idx]] ) sdfg( cells=c, edges=e, vertices=v, - connectivity_C2E=connectivity_C2E.table, - connectivity_C2V=connectivity_C2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_C2V=connectivity_C2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 6e9dfa3d64..0998ab8eab 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -11,10 +11,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common, float32 -from gt4py.next.program_processors.runners import roundtrip - -from next_tests.integration_tests import cases +from gt4py.next import allocators as next_allocators, common I = gtx.Dimension("I") @@ -154,3 +151,12 @@ def test_field_wrong_origin(): @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) + + +@pytest.mark.parametrize( + "data, skip_value", + [([0, 1, 2], None), ([0, 1, common._DEFAULT_SKIP_VALUE], common._DEFAULT_SKIP_VALUE)], +) +def test_as_connectivity(nd_array_implementation, data, skip_value): + testee = gtx.as_connectivity([I], J, nd_array_implementation.array(data)) + assert testee.skip_value is skip_value From 3fb206e46ceecf07b7ef6c668239d62d79028503 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Nov 2024 10:53:19 +0100 Subject: [PATCH 052/178] feat[next][dace]: Symbolic domain without dace array offsets (#1735) Add support for field operator domain with symbolic shape, with dimension extent in non zero-based range. --- .../runners/dace_common/utility.py | 10 +- .../gtir_builtin_translators.py | 127 ++++++++++----- .../runners/dace_fieldview/gtir_dataflow.py | 100 +++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 148 +++++++++++++----- .../runners/dace_fieldview/utility.py | 11 +- .../dace_tests/test_gtir_to_sdfg.py | 123 +++++++++++++-- 6 files changed, 367 insertions(+), 152 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 29395a30c1..3e96ef3cec 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Optional, Sequence +from typing import Final, Literal, Optional, Sequence import dace @@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" + return field_symbol_name(field_name, axis, "size") def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" + return field_symbol_name(field_name, axis, "stride") def is_field_symbol(name: str) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..60dcd8ddc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace import dace.subsets as sbs @@ -33,6 +33,34 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg +def _get_domain_indices( + dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None +) -> sbs.Indices: + """ + Helper function to construct the list of indices for a field domain, applying + an optional offset in each dimension as start index. + + Args: + dims: The field dimensions. + offsets: The range start index in each dimension. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. + """ + index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + if offsets is None: + return sbs.Indices(index_variables) + else: + return sbs.Indices( + [ + index - offset if offset != 0 else index + for index, offset in zip(index_variables, offsets, strict=True) + ] + ) + + @dataclasses.dataclass(frozen=True) class FieldopData: """ @@ -45,42 +73,59 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. + offset: List of index offsets, in each dimension, when the dimension range + does not start from zero; assume zero offset, if not set. """ dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType + offset: Optional[list[dace.symbolic.SymExpr]] + + def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: + """Create a copy of this data descriptor with a different access node.""" + assert data_node != self.dc_node + return FieldopData(data_node, self.gt_type, self.offset) def get_local_view( self, domain: FieldopDomain ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - """Helper method to access a field in local view, given a field operator domain.""" + """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain + domain_dims = [dim for dim, _, _ in domain] + domain_indices = _get_domain_indices(domain_dims) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) } + field_domain = [ + (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + for i, dim in enumerate(self.gt_type.dims) + ] local_dims = [ dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL ] - if len(local_dims) == 0: return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + self.dc_node, self.gt_type.dtype, field_domain, it_indices ) elif len(local_dims) == 1: field_dtype = itir_ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) - field_dims = [ - dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + field_domain = [ + (dim, offset) + for dim, offset in field_domain + if dim.kind != gtx_common.DimensionKind.LOCAL ] - return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + return gtir_dataflow.IteratorExpr( + self.dc_node, field_dtype, field_domain, it_indices + ) else: raise ValueError( @@ -155,9 +200,9 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) -def _get_field_shape( +def _get_field_layout( domain: FieldopDomain, -) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ Parse the field operator domain and generates the shape of the result field. @@ -174,11 +219,14 @@ def _get_field_shape( domain: The field operator domain. Returns: - A tuple of two lists: the list of field dimensions and the list of dace - array sizes in each dimension. + A tuple of three lists containing: + - the domain dimensions + - the domain offset in each dimension + - the domain size in each dimension """ - domain_dims, _, domain_ubs = zip(*domain) - return list(domain_dims), list(domain_ubs) + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes def _create_temporary_field( @@ -189,7 +237,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - field_dims, field_shape = _get_field_shape(domain) + field_dims, field_offset, field_shape = _get_field_layout(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -197,6 +245,7 @@ def _create_temporary_field( assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) + field_offset.extend(output_desc.offset) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) @@ -215,7 +264,11 @@ def _create_temporary_field( assert dataflow_output.result.gt_dtype.offset_type is not None field_dims.append(dataflow_output.result.gt_dtype.offset_type) - return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -285,7 +338,8 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + domain_dims, domain_offsets, _ = zip(*domain) + domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] @@ -350,10 +404,8 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - field_dims, field_shape = _get_field_shape(domain) - field_subset = sbs.Range.from_string( - ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) - ) + output_dims, output_offset, output_shape = _get_field_layout(domain) + output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) assert len(node.args) == 1 scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) @@ -369,26 +421,15 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) if len(node.args[0].type.dims) == 0: # zero-dimensional field input_subset = "0" - elif all( - isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr) - for dim in scalar_expr.dimensions - if dim not in field_dims - ): - input_subset = ",".join( - dace_gtir_utils.get_map_variable(dim) - if dim in field_dims - else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above - for dim in scalar_expr.dimensions - ) else: - raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + input_subset = scalar_expr.get_memlet_subset(sdfg) input_node = scalar_expr.field gt_dtype = node.args[0].type.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( @@ -400,13 +441,13 @@ def translate_broadcast_scalar( }, inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, input_nodes={input_node.data: input_node}, output_nodes={output_node.data: output_node}, external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) + return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) def translate_if( @@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg.add_temp_transient_like(inner_desc) outer_node = state.add_access(outer) - return FieldopData(outer_node, inner_data.gt_type) + return inner_data.make_copy(outer_node) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -513,7 +554,7 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(data_type, ts.FieldType): data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.ScalarType): if data_name in sdfg.symbols: @@ -522,7 +563,7 @@ def _get_data_nodes( ) else: data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) @@ -579,7 +620,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type) + return FieldopData(data_node, data_type, offset=None) def translate_make_tuple( @@ -708,7 +749,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type) + return FieldopData(temp_node, node.type, offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cfba4d61e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -90,17 +90,42 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. - dimensions: Field domain represented as a sorted list of dimensions, needed - to order the map index variables and dereference an element in the field. + field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. The offset + value is either the start index of dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted to the index + variable when constructing the memlet subset range. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ field: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - dimensions: list[gtx_common.Dimension] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return sbs.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + class DataflowInputEdge(Protocol): """ @@ -271,8 +296,17 @@ def _add_input_data_edge( src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + input_subset = ( + src_subset + if src_offset is None + else sbs.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, src, input_subset, dst_node, dst_conn) self.input_edges.append(edge) def _add_edge( @@ -440,34 +474,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field - assert len(arg_expr.dimensions) == 0 + assert len(arg_expr.field_domain) == 0 assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): - assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 - assert arg_expr.gt_dtype.offset_type is not None - field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] - else: - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_dims = arg_expr.dimensions - - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(field_dims, field_desc.shape) - ) + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] index_connectors = [ IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices @@ -494,6 +515,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: sbs.Range.from_array(field_desc), deref_node, "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], ) for dim, index_expr in field_indices: @@ -532,7 +554,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.codomain in it.dimensions + assert any(dim == offset_provider.codomain for dim, _ in it.field_domain) assert offset_provider.source_dim in it.indices origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) @@ -560,10 +582,12 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=node.type, subset=sbs.Range.from_string( ",".join( - it.indices[dim].value # type: ignore[union-attr] + str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + for (dim, offset), size in zip( + it.field_domain, field_desc.shape, strict=True + ) ) ), ) @@ -971,14 +995,13 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions + assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr - assert offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, + index_expr.value + offset_expr.value, index_expr.dc_dtype, ) else: @@ -1032,15 +1055,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - return IteratorExpr( - field=it.field, - gt_dtype=it.gt_dtype, - dimensions=it.dimensions, - indices={ - dim: (new_index if dim == offset_dim else index) - for dim, index in it.indices.items() - }, - ) + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( self, @@ -1094,7 +1112,7 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.codomain in it.dimensions + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices @@ -1117,9 +1135,7 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr( - field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices - ) + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..f15287e64c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -16,6 +16,7 @@ import abc import dataclasses +import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -98,9 +99,16 @@ def add_mapped_tasklet( class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" + @abc.abstractmethod + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" + """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... @abc.abstractmethod @@ -141,6 +149,15 @@ def _collect_symbols_in_domain_expressions( ) +def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -157,6 +174,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( + default_factory=lambda: {} + ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -167,6 +187,15 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: return self.offset_provider_type[offset] + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + if isinstance(data_type, ts.FieldType): + domain_offset = self.field_offsets.get(data_node.data, None) + else: + domain_offset = None + return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -248,12 +277,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, gt_type, flatten=True - ): + for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name ) ) return tuple_fields @@ -275,7 +302,6 @@ def _add_storage( tuple_name, gt_type.dims ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) - return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): @@ -344,7 +370,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return gtir_builtin_translators.FieldopData(temp_node, field.gt_type) + return field.make_copy(temp_node) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -405,6 +431,10 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + # Since program field arguments are passed to the SDFG as full-shape arrays, + # there is no offset that needs to be compensated. + assert len(self.field_offsets) == 0 + sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -459,7 +489,7 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - temp_fields = self._visit_expression(stmt.expr, sdfg, state) + source_fields = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field @@ -482,17 +512,26 @@ def visit_SetAt( } target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): + for source, target in zip(source_fields, target_fields, strict=True): target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient if isinstance(target.gt_type, ts.FieldType): - subset = ",".join( + target_subset = ",".join( f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) + source_subset = ( + target_subset + if source.offset is None + else ",".join( + f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" + for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) + ) + ) else: assert len(domain) == 0 - subset = "0" + target_subset = "0" + source_subset = "0" if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -501,17 +540,21 @@ def visit_SetAt( target_state = sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.dc_node.data), + target_state.add_access(source.dc_node.data), target_state.add_access(target.dc_node.data), - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) # remove isolated access node state.remove_node(target.dc_node) else: state.add_nedge( - temp.dc_node, + source.dc_node, target.dc_node, - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) return target_state or state @@ -574,17 +617,65 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] + def flatten_tuples( + name: str, + arg: gtir_builtin_translators.FieldopResult, + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: + if isinstance(arg, tuple): + tuple_type = _get_tuple_type(arg) + tuple_field_names = [ + arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list( + itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) + ) + else: + return [(name, arg)] + + lambda_arg_nodes = dict( + itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + ) + # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } + def get_field_domain_offset( + p_name: str, p_type: ts.DataType + ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: + if isinstance(p_type, ts.FieldType): + if p_name in lambda_arg_nodes: + arg = lambda_arg_nodes[p_name] + assert isinstance(arg, gtir_builtin_translators.FieldopData) + return {p_name: arg.offset} + elif field_domain_offset := self.field_offsets.get(p_name, None): + return {p_name: field_domain_offset} + elif isinstance(p_type, ts.TupleType): + p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + return functools.reduce( + lambda field_offsets, field: ( + field_offsets | get_field_domain_offset(field[0], field[1]) + ), + p_fields, + {}, + ) + return {} + + # populate mapping from field name to domain offset + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for p_name, p_type in lambda_symbols.items(): + lambda_field_offsets |= get_field_domain_offset(p_name, p_type) + # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) + lambda_translator = GTIRToSDFG( + self.offset_provider_type, lambda_symbols, lambda_field_offsets + ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -603,30 +694,11 @@ def visit_Lambda( head_state=nstate, ) - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - # Process lambda inputs # # All input arguments are passed as parameters to the nested SDFG, therefore # we they are stored as non-transient array and scalar objects. # - lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) @@ -739,7 +811,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -748,7 +820,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..118f0449c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import itertools -from typing import Any, Dict, TypeVar +from typing import Dict, TypeVar import dace @@ -58,15 +58,6 @@ def get_tuple_fields( return fields -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9c52ea81c3..f5191fbaaa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -47,7 +47,7 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() @@ -735,13 +735,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -749,13 +749,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -764,14 +766,14 @@ def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -828,13 +830,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -842,13 +844,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -857,14 +861,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -1539,6 +1543,91 @@ def test_gtir_reduce_with_cond_neighbors(): assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1722,7 +1811,7 @@ def test_gtir_let_lambda_with_cond(): def test_gtir_let_lambda_with_tuple1(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( id="let_lambda_with_tuple1", function_definitions=[], @@ -1753,10 +1842,12 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) sdfg(a, b, *z_fields, **FSYMBOLS) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) def test_gtir_let_lambda_with_tuple2(): From f6c219bd989e3c5325da1173bade4bff2ac9e650 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 26 Nov 2024 15:59:58 +0100 Subject: [PATCH 053/178] bug[next]: Fix SetAt type inference for ts.DeferredType (#1747) Fix to correctly handle tuples of ts.DeferredType. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/type_system/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..249019769b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,10 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) From f6c0498dbffd85a80a32281e5a53bfb35e00e745 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 27 Nov 2024 09:55:46 +0100 Subject: [PATCH 054/178] feat[next][dace]: Lowering to SDFG of index builtin (#1751) Implements the lowering to SDFG of the GTIR index builtin. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 ++++ .../gtir_builtin_translators.py | 83 ++++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 2 + tests/next_tests/definitions.py | 1 - .../dace_tests/test_gtir_to_sdfg.py | 50 ++++++++++- 5 files changed, 134 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. """ -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref) From 3ece412f0d78f32893d8f01ed0e74c8b38388854 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 28 Nov 2024 13:13:55 -0500 Subject: [PATCH 055/178] fix[cartesian]: Deactivate K offset write in `gt:gpu` (#1755) Following the issue logged as https://github.com/GridTools/gt4py/issues/1754 we are deactivating the K-offset write feature until we can figure out why it's failing. I will monitor any activity on the ticket if users are hit by this. --------- Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 7 +++++++ .../multi_feature_tests/test_code_generation.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." + "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..7c4956b3ef 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -667,6 +667,10 @@ def test_K_offset_write_conditional(backend): pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) From 886058496c1ebcb90ba530a796213d1fec7c7095 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 08:46:06 +0100 Subject: [PATCH 056/178] refact[next][dace]: Helper function for field operator constructor (#1743) Includes refactoring of the code for construction of field operators, in order to make it usable by the three lowering functions that construct fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and `translate_index()`. --- .../gtir_builtin_translators.py | 242 +++++++----------- 1 file changed, 94 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. - output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. - """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, From d9b38f476ee5df1995d27b7497037f3f19c9b6e6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 29 Nov 2024 02:50:43 -0500 Subject: [PATCH 057/178] hotfix[cartesian]: Fixing k offset write utest deactivate (#1757) Missed a utest in #1755 --- .../multi_feature_tests/test_code_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 7c4956b3ef..e51b3ef09d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,7 +664,7 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: From 791f67d031127872fc6375819267f59faeaf85ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 10:02:34 +0100 Subject: [PATCH 058/178] test[next]: Fix flaky failure in GTIR to SDFG tests (#1759) The SDFG name has to be unique to avoid issues with parallel build in CI tests. --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..b1ba4ccf22 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1984,7 +1984,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), From 04513ba859d5ed55ea99999f6fd826a2a542a627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 29 Nov 2024 13:57:10 +0100 Subject: [PATCH 059/178] fix[next]: use current working directory as default cache folder root (#1744) Change the root folder of the gt4py cache directory from the system temp folder to the current working directory, which is more visible and also avoids polluting shared filesystems in hpc clusters. --------- Co-authored-by: Hannes Vogt --- src/gt4py/next/config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] From d581060e5c6e8b6f64b72cce041d539956ca4727 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:39:26 +0100 Subject: [PATCH 060/178] bug[next]: ConstantFolding after create_global_tmps (#1756) Do `ConstantFolding` within `domain_union` to avoid nested minima and maxima by `create_global_tmps` --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) From a26d91f409ea5d67f168bbbc4a2157df2ed1080b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:31:13 +0100 Subject: [PATCH 061/178] fix[next]: Fix annex & type preservation in inline_lambdas (#1760) Co-authored-by: SF-N --- src/gt4py/next/iterator/transforms/inline_lambdas.py | 11 +++++------ src/gt4py/next/iterator/transforms/remap_symbols.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 7 +++++-- .../transforms_tests/test_inline_lambdas.py | 7 +++++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..ffca6cc7a7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) From 99c53004663b0b58c7ce8335bcc30e347d3686b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 22:08:39 +0100 Subject: [PATCH 062/178] refactor[next]: Use `set_at` & `as_fieldop` instead of `closure` in iterator tests (#1691) --- .../test_cartesian_offset_provider.py | 12 +++--- .../iterator_tests/test_conditional.py | 2 +- .../test_strided_offset_provider.py | 7 ++-- .../iterator_tests/test_trivial.py | 10 ++--- .../iterator_tests/test_tuple.py | 28 +++++-------- .../iterator_tests/test_anton_toy.py | 21 +++++----- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++----------- .../iterator_tests/test_hdiff.py | 10 ++--- 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin From 6f49699f00ceb9e466fa4448bab779bc061df047 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 2 Dec 2024 13:09:47 +0100 Subject: [PATCH 063/178] style[eve]: remove unused imports and fix typos (#1748) Small cleanup PR in the eve framework: - Removes a stale `.gitignore` file. As far as I understood from the git history, earlier versions of this codebase had many `.gitignore` files in many places. Looks like this one is a leftover from a previous time. - Remove a couple of stale includes. The language server marked them as unused and since tests still pass, I guess we really don't need them anymore. - Fixed a couple of typos in comments - Fixed two typos in the github PR template --- .github/pull_request_template.md | 4 ++-- src/gt4py/eve/.gitignore | 1 - src/gt4py/eve/__init__.py | 14 ++------------ src/gt4py/eve/codegen.py | 6 +++--- src/gt4py/eve/datamodels/__init__.py | 4 ++-- src/gt4py/eve/datamodels/core.py | 16 ++++++++-------- src/gt4py/eve/extended_typing.py | 4 ---- src/gt4py/eve/trees.py | 8 -------- src/gt4py/eve/type_validation.py | 2 +- src/gt4py/eve/utils.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- 11 files changed, 20 insertions(+), 43 deletions(-) delete mode 100644 src/gt4py/eve/.gitignore diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. If this PR contains code authored by new contributors please make sure: diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..5adac47da3 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -24,8 +24,7 @@ """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -89,15 +88,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +112,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. @@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..6fd9c7bb21 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..1b0e995156 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -24,7 +24,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +270,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +289,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +867,7 @@ def _substitute_typevars( def _make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +965,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +996,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." ) @@ -1085,7 +1085,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1099,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index c8e8658413..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For performance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..e150832295 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -311,7 +311,7 @@ def __call__( # ... # # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..2c66d39290 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9ce07d01bb..61756f30c9 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -230,7 +230,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." ), stacklevel=2, ) From f57d6e916e17ee2ff574ba6096ccc21911d27533 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 2 Dec 2024 20:02:44 +0100 Subject: [PATCH 064/178] fix[next]: Guard diskcache creation by file lock (#1745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The disk cache used to cache compilation in the gtfn backend has a race condition manifesting itself in `sqlite3.OperationalError: database is locked` errors if multiple python processes try to initialize the `diskcache.Cache` object concurrently. This PR fixes this by guarding the object creation by a file-based lock in the same directory as the database. While this issue occurred frequently and was observed to be fixed on distributed file systems, the lock does not guarantee correct behavior in particular for accesses to the cache (beyond opening) since the underlying SQLite database is unreliable when stored on an NFS based file system. It does however ensure correctness of concurrent cache accesses on a local file system. See more information here: https://grantjenks.com/docs/diskcache/tutorial.html#settings https://www.sqlite.org/faq.html#q5 https://github.com/tox-dev/filelock/issues/73 NFS safe locking: https://gitlab.com/warsaw/flufl.lock [Barry Warsaw / FLUFL Lock · GitLab](https://gitlab.com/warsaw/flufl.lock) --- .pre-commit-config.yaml | 1 + constraints.txt | 8 ++--- min-extra-requirements-test.txt | 1 + min-requirements-test.txt | 1 + pyproject.toml | 1 + requirements-dev.txt | 8 ++--- .../next/program_processors/runners/gtfn.py | 32 ++++++++++++++++--- 7 files changed, 40 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3b6e693f..7e1870c67f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,6 +102,7 @@ repos: - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 + - filelock==3.16.1 - frozendict==2.4.6 - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 diff --git a/constraints.txt b/constraints.txt index b4b8bc00d4..f039fa2125 100644 --- a/constraints.txt +++ b/constraints.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat -filelock==3.16.1 # via tox, virtualenv +filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.9.2 # via bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via pydantic +pydantic==2.10.0 # via bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via pydantic pydantic-settings==2.6.1 # via bump-my-version pydot==3.0.2 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.3 # via -r requirements-dev.in +tach==0.14.4 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 57c0d3969d..d7679a1f0f 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,6 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 81a1c2dea3..cf505e88d6 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,6 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 02d301957c..1e24094fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', + 'filelock>=3.0.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f95779fd5..6542be36f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.1 # via -c constraints.txt, tox, virtualenv +filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via -c constraints.txt, pydantic +pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version pydot==3.0.2 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1f3778f227..55f479c665 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,11 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import pathlib +import tempfile import warnings from typing import Any, Optional import diskcache import factory +import filelock import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -139,13 +142,34 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: class FileCache(diskcache.Cache): """ - This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, - i.e. it ensures that any resources associated with the cache are properly - released when the instance is garbage collected. + This class extends `diskcache.Cache` to ensure the cache is properly + - opened when accessed by multiple processes using a file lock. This guards the creating of the + cache object, which has been reported to cause `sqlite3.OperationalError: database is locked` + errors and slow startup times when multiple processes access the cache concurrently. While this + issue occurred frequently and was observed to be fixed on distributed file systems, the lock + does not guarantee correct behavior in particular for accesses to the cache (beyond opening) + since the underlying SQLite database is unreliable when stored on an NFS based file system. + It does however ensure correctness of concurrent cache accesses on a local file system. See + #1745 for more details. + - closed upon deletion, i.e. it ensures that any resources associated with the cache are + properly released when the instance is garbage collected. """ + def __init__(self, directory: Optional[str | pathlib.Path] = None, **settings: Any) -> None: + if directory: + lock_dir = pathlib.Path(directory).parent + else: + lock_dir = pathlib.Path(tempfile.gettempdir()) + + lock = filelock.FileLock(lock_dir / "file_cache.lock") + with lock: + super().__init__(directory=directory, **settings) + + self._init_complete = True + def __del__(self) -> None: - self.close() + if getattr(self, "_init_complete", False): # skip if `__init__` didn't finished + self.close() class GTFNCompileWorkflowFactory(factory.Factory): From e5abcd20839e35c5480b512e1c2ef9b6f01c60e4 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:55:53 +0100 Subject: [PATCH 065/178] bug[next]: Fix codegen in gtfn for unused vertical offset provider (#1746) Providing an offest provider for a vertical dimension without using that dimension in a program, e.g. no arguments are fields defined on K, resulted in erroneous C++ code. --- .../codegens/gtfn/itir_to_gtfn_ir.py | 3 +++ tests/next_tests/integration_tests/cases.py | 10 +++++++++- .../ffront_tests/test_execution.py | 15 +++++++++++++++ .../ffront_tests/test_gt4py_builtins.py | 17 ++++++++++------- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 129d81d6f9..dc0012b041 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -198,6 +198,9 @@ def _collect_offset_definitions( "Mapping an offset to a horizontal dimension in unstructured is not allowed." ) # create alias from vertical offset to vertical dimension + offset_definitions[dim.value] = TagDefinition( + name=Sym(id=dim.value), alias=_vertical_dimension + ) offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 9fb7850666..759cd1cf1f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -499,13 +499,21 @@ def unstructured_case( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, ) +@pytest.fixture +def unstructured_case_3d(unstructured_case): + return dataclasses.replace( + unstructured_case, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, + offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + ) + + def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1a51e3667d..0d994d1b22 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -41,6 +41,7 @@ Edge, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -93,6 +94,20 @@ def testee(a: cases.VField) -> cases.EField: ) +def test_horizontal_only_with_3d_mesh(unstructured_case_3d): + # test field operator operating only on horizontal fields while using an offset provider + # including a vertical dimension. + @gtx.field_operator + def testee(a: cases.VField) -> cases.VField: + return a + + cases.verify_with_default_data( + unstructured_case_3d, + testee, + ref=lambda a: a, + ) + + @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7648d34db7..ab1c625fef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -29,6 +29,7 @@ Vertex, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -105,10 +106,10 @@ def reduction_ke_field( @pytest.mark.parametrize( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) -def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].ndarray +def test_neighbor_sum(unstructured_case_3d, fop): + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray - edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")() local_dim_idx = edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( @@ -131,10 +132,10 @@ def test_neighbor_sum(unstructured_case, fop): where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( - unstructured_case, + unstructured_case_3d, fop, edge_f, - out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(), ref=ref, ) @@ -463,11 +464,13 @@ def conditional_program( ) -def test_promotion(unstructured_case): +def test_promotion(unstructured_case_3d): @gtx.field_operator def promotion( inp1: gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64] ) -> gtx.Field[[Edge, KDim], float64]: return inp1 / inp2 - cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2) + cases.verify_with_default_data( + unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2 + ) From a2551acc0cf832ed9628b2930264e1d3998cebbf Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 3 Dec 2024 21:51:05 +0100 Subject: [PATCH 066/178] feat[next]: Remove dace_iterator backend and pass_manager_legacy (#1753) The dace orchestration tests are temporarily skipped until #1742 is merged. The dace backend with SDFG optimization is temporarily disabled in unit tests until #1639 is merged. A second PR will reorganize the files in dace backend module. --- .../transforms/pass_manager_legacy.py | 181 -- .../next/program_processors/runners/dace.py | 62 +- .../runners/dace_common/dace_backend.py | 30 +- .../runners/dace_common/utility.py | 9 +- .../runners/dace_common/workflow.py | 2 +- .../runners/dace_iterator/__init__.py | 377 ---- .../runners/dace_iterator/itir_to_sdfg.py | 809 --------- .../runners/dace_iterator/itir_to_tasklet.py | 1564 ----------------- .../runners/dace_iterator/utility.py | 149 -- .../runners/dace_iterator/workflow.py | 150 -- tests/next_tests/definitions.py | 41 +- .../feature_tests/dace/test_orchestration.py | 37 +- .../ffront_tests/ffront_test_utils.py | 4 +- 13 files changed, 74 insertions(+), 3341 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/pass_manager_legacy.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/__init__.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/utility.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_iterator/workflow.py diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py deleted file mode 100644 index 94c962e92d..0000000000 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ /dev/null @@ -1,181 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR -import enum -from typing import Callable, Optional - -from gt4py.eve import utils as eve_utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs -from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet -from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.merge_let import MergeLet -from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce - - -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir - - -def apply_common_transforms( - ir: itir.Node, - *, - lift_mode=None, - offset_provider=None, - unroll_reduce=False, - common_subexpression_elimination=True, - force_inline_lambda_args=False, - unconditionally_collapse_tuples=False, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - offset_provider_type: Optional[common.OffsetProviderType] = None, -) -> itir.Program: - assert isinstance(ir, itir.FencilDefinition) - # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps - if offset_provider_type is None: - offset_provider_type = common.offset_provider_to_type(offset_provider) - - ir = fencil_to_program.FencilToProgram().apply(ir) - icdlv_uids = eve_utils.UIDGenerator() - - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) - ir = MergeLet().visit(ir) - ir = inline_fundefs.InlineFundefs().visit(ir) - - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) - ir = NormalizeShifts().visit(ir) - - for _ in range(10): - inlined = ir - - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) - # This pass is required to be in the loop such that when an `if_` call with tuple arguments - # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( - inlined, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) - - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - - if lift_mode != LiftMode.FORCE_INLINE: - raise NotImplementedError() - - # Since `CollapseTuple` relies on the type inference which does not support returning tuples - # larger than the number of closure outputs as given by the unconditional collapse, we can - # only run the unconditional version here instead of in the loop above. - if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( - ir, - ignore_tuple_size=True, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) - - ir = NormalizeShifts().visit(ir) - - ir = FuseMaps().visit(ir) - ir = CollapseListGet().visit(ir) - - if unroll_reduce: - for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) - if unrolled == ir: - break - ir = unrolled - ir = CollapseListGet().visit(ir) - ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) - ir = NormalizeShifts().visit(ir) - else: - raise RuntimeError("Reduction unrolling failed.") - - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - - ir = InlineLambdas.apply( - ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args - ) - - assert isinstance(ir, itir.Program) - return ir diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 95186e0b5d..1b3b930818 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,45 +8,34 @@ import factory +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators from gt4py.next import backend +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory -class DaCeIteratorBackendFactory(GTFNBackendFactory): +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Meta: + model = backend.Backend + class Params: - otf_workflow = factory.SubFactory( - dace_iterator_workflow.DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + name_device="gpu", ) - auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_postfix="_opt" + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + ), + name_cached="_cached", ) - use_field_canonical_representation: bool = False - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" - ) - - transforms = backend.LEGACY_TRANSFORMS - - -run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) - -run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False) - -itir_cpu = run_dace_cpu -itir_gpu = run_dace_gpu - - -class DaCeFieldviewBackendFactory(GTFNBackendFactory): - class Params: + device_type = core_defs.DeviceType.CPU otf_workflow = factory.SubFactory( dace_fieldview_workflow.DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), @@ -55,11 +44,16 @@ class Params: auto_optimize = factory.Trait(name_postfix="_opt") name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" ) + executor = factory.LazyAttribute(lambda o: o.otf_workflow) + allocator = next_allocators.StandardCPUFieldBufferAllocator() transforms = backend.DEFAULT_TRANSFORMS -gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) -gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) +run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) + +run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 56ba08015b..90e7e07ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -24,7 +24,7 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: +def _convert_arg(arg: Any, sdfg_param: str) -> Any: if not isinstance(arg, gtx_common.Field): return arg if len(arg.domain.dims) == 0: @@ -41,26 +41,14 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: raise RuntimeError( f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." ) - if not use_field_canonical_representation: - return arg.ndarray - # the canonical representation requires alphabetical ordering of the dimensions in field domain definition - sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims) - ndim = len(sorted_dims) - dim_indices = [dim_index for dim_index, _ in sorted_dims] - if isinstance(arg.ndarray, np.ndarray): - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) - else: - assert cp is not None and isinstance(arg.ndarray, cp.ndarray) - return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) - - -def _get_args( - sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool -) -> dict[str, Any]: + return arg.ndarray + + +def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { - sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation) + sdfg_param: _convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) } @@ -154,10 +142,10 @@ def get_sdfg_conn_args( def get_sdfg_args( sdfg: dace.SDFG, + offset_provider: gtx_common.OffsetProvider, *args: Any, check_args: bool = False, on_gpu: bool = False, - use_field_canonical_representation: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. @@ -166,10 +154,10 @@ def get_sdfg_args( Args: sdfg: The SDFG for which we want to get the arguments. + offset_provider: Offset provider. """ - offset_provider = kwargs["offset_provider"] - dace_args = _get_args(sdfg, args, use_field_canonical_representation) + dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 3e96ef3cec..ac15bc1cbf 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal, Optional, Sequence +from typing import Final, Literal, Optional import dace @@ -96,10 +96,3 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } - - -def get_sorted_dims( - dims: Sequence[gtx_common.Dimension], -) -> Sequence[tuple[int, gtx_common.Dimension]]: - """Sort list of dimensions in alphabetical order.""" - return sorted(enumerate(dims), key=lambda v: v[1].value) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 91e83dba9d..5d9ac863c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -150,9 +150,9 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, + offset_provider, *args, check_args=False, - offset_provider=offset_provider, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py deleted file mode 100644 index ef09cf51cd..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ /dev/null @@ -1,377 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dataclasses -import warnings -from collections import OrderedDict -from collections.abc import Callable, Sequence -from dataclasses import field -from inspect import currentframe, getframeinfo -from pathlib import Path -from typing import Any, ClassVar, Optional - -import dace -import numpy as np -from dace.sdfg import utils as sdutils -from dace.transformation.auto import auto_optimize as autoopt - -import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.ffront import decorator -from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.iterator.ir import SymRef -from gt4py.next.iterator.transforms import ( - pass_manager_legacy as legacy_itir_transforms, - program_to_fencil, -) -from gt4py.next.iterator.type_system import inference as itir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .itir_to_sdfg import ItirToSDFG - - -def preprocess_program( - program: itir.FencilDefinition, - offset_provider_type: common.OffsetProviderType, - lift_mode: legacy_itir_transforms.LiftMode, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - unroll_reduce: bool = False, -): - node = legacy_itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider_type=offset_provider_type, - symbolic_domain_sizes=symbolic_domain_sizes, - temporary_extraction_heuristics=temporary_extraction_heuristics, - unroll_reduce=unroll_reduce, - ) - - node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - - if isinstance(node, itir.Program): - fencil_definition = program_to_fencil.program_to_fencil(node) - tmps = node.declarations - assert all(isinstance(tmp, itir.Temporary) for tmp in tmps) - else: - raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.") - - return fencil_definition, tmps - - -def build_sdfg_from_itir( - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - auto_optimize: bool = False, - on_gpu: bool = False, - column_axis: Optional[common.Dimension] = None, - lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - load_sdfg_from_file: bool = False, - save_sdfg: bool = True, - use_field_canonical_representation: bool = True, -) -> dace.SDFG: - """Translate a Fencil into an SDFG. - - Args: - program: The Fencil that should be translated. - arg_types: Types of the arguments passed to the fencil. - offset_provider: The set of offset providers that should be used. - auto_optimize: Apply DaCe's `auto_optimize` heuristic. - on_gpu: Performs the translation for GPU, defaults to `False`. - column_axis: The column axis to be used, defaults to `None`. - lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. - symbolic_domain_sizes: Used for generation of liskov bindings when temporaries are enabled. - load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. - save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. - use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - """ - - sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" - if load_sdfg_from_file and Path(sdfg_filename).exists(): - sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename) - sdfg.validate() - return sdfg - - # visit ITIR and generate SDFG - program, tmps = preprocess_program( - program, - offset_provider_type, - lift_mode, - symbolic_domain_sizes, - temporary_extraction_heuristics, - ) - sdfg_genenerator = ItirToSDFG( - list(arg_types), - offset_provider_type, - tmps, - use_field_canonical_representation, - column_axis, - ) - sdfg = sdfg_genenerator.visit(program) - if sdfg is None: - raise RuntimeError(f"Visit failed for program {program.id}.") - - for nested_sdfg in sdfg.all_sdfgs_recursive(): - if not nested_sdfg.debuginfo: - _, frameinfo = ( - warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.", - stacklevel=2, - ), - getframeinfo(currentframe()), # type: ignore[arg-type] - ) - nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename - ) - - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) - - # run DaCe transformations to simplify the SDFG - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO: Investigate performance improvement from SDFG specialization with constant symbols, - # for array shape and strides, although this would imply JIT compilation. - symbols: dict[str, int] = {} - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) - elif on_gpu: - autoopt.apply_gpu_storage(sdfg) - - if on_gpu: - sdfg.apply_gpu_transformations() - - # Store the sdfg such that we can later reuse it. - if save_sdfg: - sdfg.save(sdfg_filename) - - return sdfg - - -@dataclasses.dataclass(frozen=True) -class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): - """Extension of GT4Py Program implementing the SDFGConvertible interface.""" - - sdfg_closure_vars: dict[str, Any] = field(default_factory=dict) - - # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, - # there is no name mangling of the connectivity tables used across the nested SDFGs - # since they share the same memory address. - connectivity_tables_data_descriptors: ClassVar[ - dict[str, dace.data.Array] - ] = {} # symbolically defined - - def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: - if "dace" not in self.backend.name.lower(): # type: ignore[union-attr] - raise ValueError("The SDFG can be generated only for the DaCe backend.") - - params = {str(p.id): p.type for p in self.itir.params} - fields = {str(p.id): p.type for p in self.itir.params if hasattr(p.type, "dims")} - arg_types = [*params.values()] - - dace_parsed_args = [*args, *kwargs.values()] - gt4py_program_args = [*params.values()] - _crosscheck_dace_parsing(dace_parsed_args, gt4py_program_args) - - if self.connectivities is None: - raise ValueError( - "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." - ) - offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} - - sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] - self.itir, - arg_types, - offset_provider_type=offset_provider_type, - column_axis=kwargs.get("column_axis", None), - ) - self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ - - # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field - # Add them as dynamic properties to the SDFG - - assert all( - isinstance(in_field, SymRef) - for closure in self.itir.closures - for in_field in closure.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_fields = [ - str(in_field.id) # type: ignore[union-attr] # ensured by assert - for closure in self.itir.closures - for in_field in closure.inputs - if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert - ] - sdfg.gt4py_program_input_fields = { - in_field: dim - for in_field in input_fields - for dim in fields[in_field].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - output_fields = [] - for closure in self.itir.closures: - output = closure.output - if isinstance(output, itir.SymRef): - if str(output.id) in fields: - output_fields.append(str(output.id)) - else: - for arg in output.args: - if str(arg.id) in fields: # type: ignore[attr-defined] - output_fields.append(str(arg.id)) # type: ignore[attr-defined] - sdfg.gt4py_program_output_fields = { - output: dim - for output in output_fields - for dim in fields[output].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - sdfg.offset_providers_per_input_field = {} - itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider_type=offset_provider_type - ) - itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) - for closure in itir_tmp_fencil.closures: - params_shifts = itir_transforms.trace_shifts.trace_stencil( - closure.stencil, num_args=len(closure.inputs) - ) - for param, shifts in zip(closure.inputs, params_shifts): - assert isinstance( - param, SymRef - ) # backend only supports SymRef inputs, not `index` calls - if not isinstance(param.id, str): - continue - if param.id not in sdfg.gt4py_program_input_fields: - continue - sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) - - return sdfg - - def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: - """ - Returns the closure arrays of the SDFG represented by this object - as a mapping between array name and the corresponding value. - - The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. - The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that - the offset providers are not part of GT4Py Program's arguments. - Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. - """ - offset_provider_type = self.connectivities - - # Define DaCe symbols - connectivity_table_size_symbols = { - dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - connectivity_table_stride_symbols = { - dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - symbols = {**connectivity_table_size_symbols, **connectivity_table_stride_symbols} - - # Define the storage location (e.g. CPU, GPU) of the connectivity tables - if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - if not isinstance(v, common.NeighborConnectivityType): - continue - if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: - Program.connectivity_tables_data_descriptors["storage"] = ( - self.sdfg_closure_vars[ - "sdfg.arrays" - ][dace_utils.connectivity_identifier(k)].storage - ) - break - - # Build the closure dictionary - closure_dict = {} - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - conn_id = dace_utils.connectivity_identifier(k) - if ( - isinstance(v, common.NeighborConnectivityType) - and conn_id in self.sdfg_closure_vars["sdfg.arrays"] - ): - if conn_id not in Program.connectivity_tables_data_descriptors: - Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, - shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], - ], - strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], - ], - storage=Program.connectivity_tables_data_descriptors["storage"], - ) - closure_dict[conn_id] = Program.connectivity_tables_data_descriptors[conn_id] - - return closure_dict - - def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: - args = [] - for arg in self.past_stage.past_node.params: - args.append(arg.id) - return (args, []) - - -def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool: - for dace_parsed_arg, gt4py_program_arg in zip(dace_parsed_args, gt4py_program_args): - if isinstance(dace_parsed_arg, dace.data.Scalar): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) - elif isinstance( - dace_parsed_arg, (bool, int, float, str, np.bool_, np.integer, np.floating, np.str_) - ): # compile-time constant scalar - assert isinstance(gt4py_program_arg, ts.ScalarType) - if isinstance(dace_parsed_arg, (bool, np.bool_)): - assert gt4py_program_arg.kind == ts.ScalarKind.BOOL - elif isinstance(dace_parsed_arg, (int, np.integer)): - assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] - elif isinstance(dace_parsed_arg, (float, np.floating)): - assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] - elif isinstance(dace_parsed_arg, (str, np.str_)): - assert gt4py_program_arg.kind == ts.ScalarKind.STRING - elif isinstance(dace_parsed_arg, dace.data.Array): - assert isinstance(gt4py_program_arg, ts.FieldType) - assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) - elif isinstance( - dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) - ): # offset_provider - continue - else: - raise ValueError(f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}") - - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py deleted file mode 100644 index 823943cfd5..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ /dev/null @@ -1,809 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import warnings -from typing import Optional, Sequence, cast - -import dace -from dace.sdfg.state import LoopRegion - -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt - -from .itir_to_tasklet import ( - Context, - GatherOutputSymbolsPass, - PythonTaskletCodegen, - SymbolExpr, - TaskletExpr, - ValueExpr, - closure_to_tasklet_sdfg, - is_scan, -) -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_var_name, -) - - -def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: - """ - Parse stencil expression to extract the scan arguments. - - Returns - ------- - tuple(is_forward, init_carry) - The output tuple fields verify the following semantics: - - is_forward: forward boolean flag - - init_carry: carry initial value - """ - stencil_fobj = cast(FunCall, stencil) - is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) - init_carry = stencil_fobj.args[2] - assert isinstance(init_carry, Literal) - return is_forward.value == "True", init_carry - - -def _get_scan_dim( - column_axis: Dimension, - storage_types: dict[str, ts.TypeSpec], - output: SymRef, - use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: - """ - Extract information about the scan dimension. - - Returns - ------- - tuple(scan_dim_name, scan_dim_index, scan_dim_dtype) - The output tuple fields verify the following semantics: - - scan_dim_name: name of the scan dimension - - scan_dim_index: domain index of the scan dimension - - scan_dim_dtype: data type along the scan dimension - """ - output_type = storage_types[output.id] - assert isinstance(output_type, ts.FieldType) - sorted_dims = [ - dim - for _, dim in ( - dace_utils.get_sorted_dims(output_type.dims) - if use_field_canonical_representation - else enumerate(output_type.dims) - ) - ] - return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype) - - -def _make_array_shape_and_strides( - name: str, - dims: Sequence[Dimension], - offset_provider_type: common.OffsetProviderType, - sort_dims: bool, -) -> tuple[list[dace.symbol], list[dace.symbol]]: - """ - Parse field dimensions and allocate symbols for array shape and strides. - - For local dimensions, the size is known at compile-time and therefore - the corresponding array shape dimension is set to an integer literal value. - - Returns - ------- - tuple(shape, strides) - The output tuple fields are arrays of dace symbolic expressions. - """ - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) - shape = [ - ( - connectivity_types[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) - ) - for i, dim in sorted_dims - ] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i, _ in sorted_dims - ] - return shape, strides - - -def _check_no_lifts(node: itir.StencilClosure): - """ - Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. - - Returns - ------- - True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. - """ - neighbors_call_count = 0 - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): - if getattr(fun, "id", "") == "neighbors": - neighbors_call_count = 3 - elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: - return False - neighbors_call_count = max(0, neighbors_call_count - 1) - return True - - -class ItirToSDFG(eve.NodeVisitor): - param_types: list[ts.TypeSpec] - storage_types: dict[str, ts.TypeSpec] - column_axis: Optional[Dimension] - offset_provider_type: common.OffsetProviderType - unique_id: int - use_field_canonical_representation: bool - - def __init__( - self, - param_types: list[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - tmps: list[itir.Temporary], - use_field_canonical_representation: bool, - column_axis: Optional[Dimension] = None, - ): - self.param_types = param_types - self.column_axis = column_axis - self.offset_provider_type = offset_provider_type - self.storage_types = {} - self.tmps = tmps - self.use_field_canonical_representation = use_field_canonical_representation - - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): - if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider_type, sort_dimensions - ) - dtype = dace_utils.as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - elif isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - else: - raise NotImplementedError() - self.storage_types[name] = type_ - - def add_storage_for_temporaries( - self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG - ) -> dict[str, str]: - symbol_map: dict[str, TaskletExpr] = {} - # The shape of temporary arrays might be defined based on scalar values passed as program arguments. - # Here we collect these values in a symbol map. - for sym in node_params: - if isinstance(sym.type, ts.ScalarType): - name_ = str(sym.id) - symbol_map[name_] = SymbolExpr(name_, dace_utils.as_dace_type(sym.type)) - - tmp_symbols: dict[str, str] = {} - for tmp in self.tmps: - tmp_name = str(tmp.id) - - # We visit the domain of the temporary field, passing the set of available symbols. - assert isinstance(tmp.domain, itir.FunCall) - domain_ctx = Context(program_sdfg, defs_state, symbol_map) - tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - - if isinstance(tmp.type, ts.TupleType): - raise NotImplementedError("Temporaries of tuples are not supported.") - assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) - - # We store the FieldType for this temporary array. - self.storage_types[tmp_name] = tmp.type - - # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. - # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics - # to assign optimal stride values. - tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) - _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, dace_utils.as_dace_type(tmp.dtype), transient=True - ) - - # Loop through all dimensions to visit the symbolic expressions for array shape and offset. - # These expressions are later mapped to interstate symbols. - for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - # The temporary field has a dimension range defined by `begin` and `end` values. - # Therefore, the actual size is given by the difference `end.value - begin.value`. - # Instead of allocating the actual size, we allocate space to enable indexing from 0 - # because we want to avoid using dace array offsets (which will be deprecated soon). - # The result should still be valid, but the stencil will be using only a subset - # of the array. - if not (isinstance(begin, SymbolExpr) and begin.value == "0"): - warnings.warn( - f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 - ) - tmp_symbols[str(shape_sym)] = end.value - - return tmp_symbols - - def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = self.storage_types[field_name] - assert isinstance(field_type, ts.FieldType) - if self.use_field_canonical_representation: - field_index = [ - index[dim.value] for _, dim in dace_utils.get_sorted_dims(field_type.dims) - ] - else: - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet(data=field_name, subset=subset) - - def get_output_nodes( - self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dict[str, dace.nodes.AccessNode]: - # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) - output_symbols_pass.visit(closure.output) - # Visit output node again to generate the corresponding tasklet - context = Context(sdfg, state, output_symbols_pass.symbol_refs) - translator = PythonTaskletCodegen( - self.offset_provider_type, context, self.use_field_canonical_representation - ) - output_nodes = flatten_list(translator.visit(closure.output)) - return {node.value.data: node.value for node in output_nodes} - - def visit_FencilDefinition(self, node: itir.FencilDefinition): - program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace_utils.debug_info(node) - entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - - # Filter neighbor tables from offset providers. - connectivity_types = get_used_connectivities(node, self.offset_provider_type) - - # Add program parameters as SDFG storages. - for param, type_ in zip(node.params, self.param_types): - self.add_storage( - program_sdfg, str(param.id), type_, self.use_field_canonical_representation - ) - - if self.tmps: - tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) - # on the first interstate edge define symbols for shape and offsets of temporary arrays - last_state = program_sdfg.add_state("init_symbols_for_temporaries") - program_sdfg.add_edge( - entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) - ) - else: - last_state = entry_state - - # Add connectivities as SDFG storages. - for offset, connectivity_type in connectivity_types.items(): - scalar_type = tt.from_dtype(connectivity_type.dtype) - type_ = ts.FieldType( - [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type - ) - self.add_storage( - program_sdfg, - dace_utils.connectivity_identifier(offset), - type_, - sort_dimensions=False, - ) - - # Create a nested SDFG for all stencil closures. - for closure in node.closures: - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg, input_names, output_names = self.visit( - closure, array_table=program_sdfg.arrays - ) - - # Create a new state for the closure. - last_state = program_sdfg.add_state_after(last_state) - - # Create memlets to transfer the program parameters - input_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in input_names - } - output_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in output_names - } - - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) - - # Insert the closure's SDFG as a nested SDFG of the program. - nsdfg_node = last_state.add_nested_sdfg( - sdfg=closure_sdfg, - parent=program_sdfg, - inputs=set(input_names), - outputs=set(output_names), - symbol_mapping=symbol_mapping, - debuginfo=closure_sdfg.debuginfo, - ) - - # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - - for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) - - # Create the call signature for the SDFG. - # Only the arguments requiered by the Fencil, i.e. `node.params` are added as positional arguments. - # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. - program_sdfg.arg_names = [str(a) for a in node.params] - - program_sdfg.validate() - return program_sdfg - - def visit_StencilClosure( - self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> tuple[dace.SDFG, list[str], list[str]]: - assert _check_no_lifts(node) - - # Create the closure's nested SDFG and single state. - closure_sdfg = dace.SDFG(name="closure") - closure_sdfg.debuginfo = dace_utils.debug_info(node) - closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) - output_names = [k for k, _ in output_nodes.items()] - - # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. - input_transients_mapping = {} - for name in [*input_names, *connectivity_names, *output_names]: - if name in closure_sdfg.arrays: - assert name in input_names and name in output_names - # In case of closures with in/out fields, there is risk of race condition - # between read/write access nodes in the (asynchronous) map tasklet. - transient_name = unique_var_name() - closure_sdfg.add_array( - transient_name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - transient=True, - ) - closure_init_state.add_nedge( - closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), - closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), - dace.Memlet.from_array(name, closure_sdfg.arrays[name]), - ) - input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) - - input_field_names = [ - input_name - for input_name in input_names - if isinstance(self.storage_types[input_name], ts.FieldType) - ] - - # Closure outputs should all be fields - assert all( - isinstance(self.storage_types[output_name], ts.FieldType) - for output_name in output_names - ) - - # Update symbol table and get output domain of the closure - program_arg_syms: dict[str, TaskletExpr] = {} - for name, type_ in self.storage_types.items(): - if isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in input_names: - out_name = unique_var_name() - closure_sdfg.add_scalar(out_name, dtype, transient=True) - out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", - {}, - {"__result"}, - f"__result = {name}", - debuginfo=closure_sdfg.debuginfo, - ) - access = closure_init_state.add_access( - out_name, debuginfo=closure_sdfg.debuginfo - ) - value = ValueExpr(access, dtype) - memlet = dace.Memlet(data=out_name, subset="0") - closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) - program_arg_syms[name] = value - else: - program_arg_syms[name] = SymbolExpr(name, dtype) - else: - assert isinstance(type_, ts.FieldType) - # make shape symbols (corresponding to field size) available as arguments to domain visitor - if name in input_names or name in output_names: - field_symbols = [ - val - for val in closure_sdfg.arrays[name].shape - if isinstance(val, dace.symbol) and str(val) not in input_names - ] - for sym in field_symbols: - sym_name = str(sym) - program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, closure_ctx) - - # Map SDFG tasklet arguments to parameters - input_local_names = [ - ( - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else ( - input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data - ) - ) - for input_name in input_names - ] - input_memlets = [ - dace.Memlet.from_array(name, closure_sdfg.arrays[name]) - for name in [*input_local_names, *connectivity_names] - ] - - # create and write to transient that is then copied back to actual output array to avoid aliasing of - # same memory in nested SDFG with different names - output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} - # scan operator should always be the first function call in a closure - if is_scan(node.stencil): - assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" - transient_name, output_name = next(iter(output_connectors_mapping.items())) - - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, transient_name - ) - results = [transient_name] - - _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] - output_subset = f"{scan_lb.value}:{scan_ub.value}" - - domain_subset = { - dim: ( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" - ) - for dim, _ in closure_domain - } - output_memlets = [self.create_memlet_at(output_name, domain_subset)] - else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( - node, closure_sdfg.arrays, closure_domain - ) - - output_subset = "0" - - output_memlets = [ - self.create_memlet_at(output_name, {dim: f"i_{dim}" for dim, _ in closure_domain}) - for output_name in output_connectors_mapping.values() - ] - - input_mapping = { - param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) - } - output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) - - nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( - closure_state, - sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - output_nodes=output_nodes, - debuginfo=nsdfg.debuginfo, - ) - access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} - for edge in closure_state.in_edges(map_exit): - memlet = edge.data - if memlet.data not in output_connectors_mapping: - continue - transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) - closure_state.add_edge( - nsdfg_node, - edge.src_conn, - transient_access, - None, - dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), - ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset - ) - closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) - closure_state.remove_edge(edge) - access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - - return closure_sdfg, input_field_names + connectivity_names, output_names - - def _visit_scan_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - output_name: str, - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: - # extract scan arguments - is_forward, init_carry_value = _get_scan_args(node.stencil) - # select the scan dimension based on program argument for column axis - assert self.column_axis - assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( - self.column_axis, - self.storage_types, - node.output, - self.use_field_canonical_representation, - ) - - assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - else: - scan_lb_str = lb_str - scan_ub_str = ub_str - - # the scan operator is implemented as an SDFG to be nested in the closure SDFG - scan_sdfg = dace.SDFG(name="scan") - scan_sdfg.debuginfo = dace_utils.debug_info(node) - - # the carry value of the scan operator exists only in the scope of the scan sdfg - scan_carry_name = unique_var_name() - scan_sdfg.add_scalar( - scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True - ) - - # create a loop region for lambda call over the scan dimension - scan_loop_var = f"i_{scan_dim}" - if is_forward: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} < {scan_ub_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_lb_str}", - update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", - inverted=False, - ) - else: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} >= {scan_lb_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", - update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", - inverted=False, - ) - scan_sdfg.add_node(scan_loop) - compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) - update_state = scan_loop.add_state("lambda_update") - scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) - - start_state = scan_sdfg.add_state("start", is_start_block=True) - scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) - - # tasklet for initialization of carry - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", - {}, - {"__result"}, - f"__result = {init_carry_value}", - debuginfo=scan_sdfg.debuginfo, - ) - start_state.add_edge( - carry_init_tasklet, - "__result", - start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), - None, - dace.Memlet(data=scan_carry_name, subset="0"), - ) - - # add storage to scan SDFG for inputs - for name in [*input_names, *connectivity_names]: - assert name not in scan_sdfg.arrays - if isinstance(self.storage_types[name], ts.FieldType): - scan_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - scan_sdfg.add_scalar( - name, - dtype=dace_utils.as_dace_type(cast(ts.ScalarType, self.storage_types[name])), - ) - # add storage to scan SDFG for output - scan_sdfg.add_array( - output_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - - # implement the lambda function as a nested SDFG that computes a single item in the scan dimension - lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - input_arrays = [(scan_carry_name, scan_dtype)] + [ - (name, self.storage_types[name]) for name in input_names - ] - connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - lambda_context, lambda_outputs = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - lambda_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - lambda_input_names = [name for name, _ in input_arrays] - lambda_output_names = [connector.value.data for connector in lambda_outputs] - - input_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in lambda_input_names - ] - connectivity_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in connectivity_names - ] - input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - - scan_inner_node = compute_state.add_nested_sdfg( - lambda_context.body, - parent=scan_sdfg, - inputs=set(lambda_input_names) | set(connectivity_names), - outputs=set(lambda_output_names), - symbol_mapping=symbol_mapping, - debuginfo=lambda_context.body.debuginfo, - ) - - # connect scan SDFG to lambda inputs - for name, memlet in array_mapping.items(): - access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) - - output_names = [output_name] - assert len(lambda_output_names) == 1 - # connect lambda output to scan SDFG - for name, connector in zip(output_names, lambda_output_names): - compute_state.add_edge( - scan_inner_node, - connector, - compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), - None, - dace.Memlet(data=name, subset=scan_loop_var), - ) - - update_state.add_nedge( - update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), - ) - - return scan_sdfg, map_ranges, scan_dim_index - - def _visit_parallel_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - - # Create an SDFG for the tasklet that computes a single item of the output domain. - index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - - input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in connectivity_names] - - context, results = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - index_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - return context.body, map_ranges, [r.value.data for r in results] - - def _visit_domain( - self, node: itir.FunCall, context: Context - ) -> tuple[tuple[str, tuple[SymbolExpr | ValueExpr, SymbolExpr | ValueExpr]], ...]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - bounds: list[tuple[str, tuple[ValueExpr, ValueExpr]]] = [] - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - translator = PythonTaskletCodegen( - self.offset_provider_type, - context, - self.use_field_canonical_representation, - ) - lb = translator.visit(lower_bound)[0] - ub = translator.visit(upper_bound)[0] - bounds.append((dimension.value, (lb, ub))) - - return tuple(bounds) - - @staticmethod - def _check_shift_offsets_are_literals(node: itir.StencilClosure): - fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) - shifts = [nd for nd in fun_calls if getattr(nd.fun, "id", "") == "shift"] - for shift in shifts: - if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args): - return False - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py deleted file mode 100644 index 2b2669187a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ /dev/null @@ -1,1564 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -import dataclasses -import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, TypeAlias, cast - -import dace -import numpy as np - -import gt4py.eve.codegen -from gt4py import eve -from gt4py.next import common -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_name, - unique_var_name, -) - - -_TYPE_MAPPING = { - "float": dace.float64, - "float32": dace.float32, - "float64": dace.float64, - "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, - "int32": dace.int32, - "int64": dace.int64, - "bool": dace.bool_, -} - - -def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, - # let it be a field, iterator, or directly a scalar. The caller should take care of the - # extraction. - dtype: ts.TypeSpec - if isinstance(type_, ts.FieldType): - dtype = type_.dtype - elif isinstance(type_, it_ts.IteratorType): - dtype = type_.element_type - else: - dtype = type_ - assert isinstance(dtype, ts.ScalarType) - return _TYPE_MAPPING[dtype.kind.name.lower()] - - -def get_reduce_identity_value(op_name_: str, type_: Any): - if op_name_ == "plus": - init_value = type_(0) - elif op_name_ == "multiplies": - init_value = type_(1) - elif op_name_ == "minimum": - init_value = type_("inf") - elif op_name_ == "maximum": - init_value = type_("-inf") - else: - raise NotImplementedError() - - return init_value - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -# Define type of variables used for field indexing -_INDEX_DTYPE = _TYPE_MAPPING["int64"] - - -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - -@dataclasses.dataclass -class IteratorExpr: - field: dace.nodes.AccessNode - indices: dict[str, dace.nodes.AccessNode] - dtype: dace.typeclass - dimensions: list[str] - - -# Union of possible expression types -TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr - - -@dataclasses.dataclass -class Context: - body: dace.SDFG - state: dace.SDFGState - symbol_map: dict[str, TaskletExpr] - # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_identity: Optional[SymbolExpr] - - def __init__( - self, - body: dace.SDFG, - state: dace.SDFGState, - symbol_map: dict[str, TaskletExpr], - reduce_identity: Optional[SymbolExpr] = None, - ): - self.body = body - self.state = state - self.symbol_map = symbol_map - self.reduce_identity = reduce_identity - - -def _visit_lift_in_neighbors_reduction( - transformer: PythonTaskletCodegen, - node: itir.FunCall, - node_args: Sequence[IteratorExpr | list[ValueExpr]], - connectivity_type: common.NeighborConnectivityType, - map_entry: dace.nodes.MapEntry, - map_exit: dace.nodes.MapExit, - neighbor_index_node: dace.nodes.AccessNode, - neighbor_value_node: dace.nodes.AccessNode, -) -> list[ValueExpr]: - assert transformer.context.reduce_identity is not None - neighbor_dim = connectivity_type.codomain.value - origin_dim = connectivity_type.source_dim.value - - lifted_args: list[IteratorExpr | ValueExpr] = [] - for arg in node_args: - if isinstance(arg, IteratorExpr): - if origin_dim in arg.indices: - lifted_indices = arg.indices.copy() - lifted_indices.pop(origin_dim) - lifted_indices[neighbor_dim] = neighbor_index_node - lifted_args.append( - IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) - ) - else: - lifted_args.append(arg) - else: - lifted_args.append(arg[0]) - - lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) - assert len(inner_outputs) == 1 - inner_out_connector = inner_outputs[0].value.data - - input_nodes = {} - iterator_index_nodes = {} - lifted_index_connectors = [] - - for x, y in inner_inputs: - if isinstance(y, IteratorExpr): - field_connector, inner_index_table = x - input_nodes[field_connector] = y.field - for dim, connector in inner_index_table.items(): - if dim == neighbor_dim: - lifted_index_connectors.append(connector) - iterator_index_nodes[connector] = y.indices[dim] - else: - assert isinstance(y, ValueExpr) - input_nodes[x] = y.value - - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - parent_sdfg = transformer.context.body - parent_state = transformer.context.state - - input_mapping = { - connector: dace.Memlet.from_array(node.data, node.desc(parent_sdfg)) - for connector, node in input_nodes.items() - } - connectivity_mapping = { - name: dace.Memlet.from_array(name, parent_sdfg.arrays[name]) for name in connectivity_names - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) - - nested_sdfg_node = parent_state.add_nested_sdfg( - lift_context.body, - parent_sdfg, - inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, - outputs={inner_out_connector}, - symbol_mapping=symbol_mapping, - debuginfo=lift_context.body.debuginfo, - ) - - for connectivity_connector, memlet in connectivity_mapping.items(): - parent_state.add_memlet_path( - parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), - map_entry, - nested_sdfg_node, - dst_conn=connectivity_connector, - memlet=memlet, - ) - - for inner_connector, access_node in input_nodes.items(): - parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=input_mapping[inner_connector], - ) - - for inner_connector, access_node in iterator_index_nodes.items(): - memlet = dace.Memlet(data=access_node.data, subset="0") - if inner_connector in lifted_index_connectors: - parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) - else: - parent_state.add_memlet_path( - access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet - ) - - parent_state.add_memlet_path( - nested_sdfg_node, - map_exit, - neighbor_value_node, - src_conn=inner_out_connector, - memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), - ) - - if connectivity_type.has_skip_values: - # check neighbor validity on if/else inter-state edge - # use one branch for connectivity case - start_state = lift_context.body.add_state_before( - lift_context.body.start_state, - "start", - condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}", - ) - # use the other branch for skip value case - skip_neighbor_state = lift_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), - None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), - ) - lift_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), - ) - - return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] - - -def builtin_neighbors( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - sdfg: dace.SDFG = transformer.context.body - state: dace.SDFGState = transformer.context.state - - di = dace_utils.debug_info(node, default=sdfg.debuginfo) - offset_literal, data = node_args - assert isinstance(offset_literal, itir.OffsetLiteral) - offset_dim = offset_literal.value - assert isinstance(offset_dim, str) - connectivity_type = transformer.offset_provider_type[offset_dim] - if not isinstance(connectivity_type, common.NeighborConnectivityType): - raise NotImplementedError( - "Neighbor reduction only implemented for connectivity based on neighbor tables." - ) - - lift_node = None - if isinstance(data, FunCall): - assert isinstance(data.fun, itir.FunCall) - fun_node = data.fun - if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": - lift_node = fun_node - lift_args = transformer.visit(data.args) - iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) - if lift_node is None: - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[connectivity_type.source_dim.value] - - assert transformer.context.reduce_identity is not None - assert transformer.context.reduce_identity.dtype == iterator.dtype - - # gather the neighbors in a result array dimensioned for `max_neighbors` - neighbor_value_var = unique_var_name() - sdfg.add_array( - neighbor_value_var, - dtype=iterator.dtype, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) - - # allocate scalar to store index for direct addressing of neighbor field - neighbor_index_var = unique_var_name() - sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) - neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) - - # generate unique map index name to avoid conflict with other maps inside same state - neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") - me, mx = state.add_map( - f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, - debuginfo=di, - ) - - table_name = dace_utils.connectivity_identifier(offset_dim) - shift_tasklet = state.add_tasklet( - "shift", - code=f"__result = __table[__idx, {neighbor_map_index}]", - inputs={"__table", "__idx"}, - outputs={"__result"}, - debuginfo=di, - ) - state.add_memlet_path( - state.add_access(table_name, debuginfo=di), - me, - shift_tasklet, - memlet=dace.Memlet.from_array(table_name, sdfg.arrays[table_name]), - dst_conn="__table", - ) - state.add_memlet_path( - origin_index_node, - me, - shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0"), - dst_conn="__idx", - ) - state.add_edge( - shift_tasklet, - "__result", - neighbor_index_node, - None, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - - if lift_node is not None: - _visit_lift_in_neighbors_reduction( - transformer, - lift_node, - lift_args, - connectivity_type, - me, - mx, - neighbor_index_node, - neighbor_value_node, - ) - else: - sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) - data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__data = __field[{data_access_index}] " - + ( - f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if connectivity_type.has_skip_values - else "" - ), - inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, - outputs={"__data"}, - debuginfo=di, - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=dace.Memlet.from_array(iterator.field.data, field_desc), - dst_conn="__field", - ) - for dim in iterator.dimensions: - connector = f"{dim}_v" - if dim == connectivity_type.codomain.value: - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - connector, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - else: - state.add_memlet_path( - iterator.indices[dim], - me, - data_access_tasklet, - dst_conn=connector, - memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), - ) - - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), - src_conn="__data", - ) - - if not connectivity_type.has_skip_values: - return [ValueExpr(neighbor_value_node, iterator.dtype)] - else: - """ - In case of neighbor tables with skip values, in addition to the array of neighbor values this function also - returns an array of booleans to indicate if the neighbor value is present or not. This node is only used - for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, - the regular case, this node will be removed by the simplify pass. - """ - neighbor_valid_var = unique_var_name() - sdfg.add_array( - neighbor_valid_var, - dtype=dace.dtypes.bool, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) - - neighbor_valid_tasklet = state.add_tasklet( - f"check_valid_neighbor_{offset_dim}", - {"__idx"}, - {"__valid"}, - f"__valid = True if __idx != {neighbor_skip_value} else False", - debuginfo=di, - ) - state.add_edge( - neighbor_index_node, - None, - neighbor_valid_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - neighbor_valid_tasklet, - mx, - neighbor_valid_node, - memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), - src_conn="__valid", - ) - return [ - ValueExpr(neighbor_value_node, iterator.dtype), - ValueExpr(neighbor_valid_node, dace.dtypes.bool), - ] - - -def builtin_can_deref( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - # first visit shift, to get set of indices for deref - can_deref_callable = node_args[0] - assert isinstance(can_deref_callable, itir.FunCall) - shift_callable = can_deref_callable.fun - assert isinstance(shift_callable, itir.FunCall) - assert isinstance(shift_callable.fun, itir.SymRef) - assert shift_callable.fun.id == "shift" - iterator = transformer._visit_shift(can_deref_callable) - - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it - if not isinstance(iterator, IteratorExpr): - assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) - # We can always deref a value expression, therefore hard-code `can_deref` to True. - # Returning a SymbolExpr would be preferable, but it requires update to type-checking. - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name, debuginfo=di) - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di - ), - "_out", - result_node, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_node, dace.dtypes.bool)] - - # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] - internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di - ) - - -def builtin_if( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - assert len(node_args) == 3 - sdfg = transformer.context.body - current_state = transformer.context.state - is_start_state = sdfg.start_block == current_state - - # build an empty state to join true and false branches - join_state = sdfg.add_state_before(current_state, "join") - - def build_if_state(arg, state): - symbol_map = copy.deepcopy(transformer.context.symbol_map) - node_context = Context(sdfg, state, symbol_map) - node_taskgen = PythonTaskletCodegen( - transformer.offset_provider_type, - node_context, - transformer.use_field_canonical_representation, - ) - return node_taskgen.visit(arg) - - # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state - stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) - stmt_node = build_if_state(node_args[0], stmt_state)[0] - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - assert sdfg.arrays[stmt_node.value.data].shape == (1,) - - # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state - tbr_state = sdfg.add_state("true_branch") - sdfg.add_edge( - stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") - ) - sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) - # - fbr_state = sdfg.add_state("false_branch") - sdfg.add_edge( - stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") - ) - sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) - - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - # make the result of the if-statement evaluation available inside current state - ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) - - # we distinguish between select if-statements, where both true and false branches are symbolic expressions, - # and therefore do not require exclusive branch execution, and regular if-statements where at least one branch - # is a value expression, which has to be evaluated at runtime with conditional state transition - result_values = [] - assert len(tbr_values) == len(fbr_values) - for tbr_value, fbr_value in zip(tbr_values, fbr_values): - assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) - assert isinstance(fbr_value, (SymbolExpr, ValueExpr)) - assert tbr_value.dtype == fbr_value.dtype - - if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)): - # both branches return symbolic expressions, therefore the if-node can be translated - # to a select-tasklet inside current state - # TODO: use select-memlet when it becomes available in dace - code = f"{tbr_value.value} if _cond else {fbr_value.value}" - if_expr = transformer.add_expr_tasklet( - [(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select" - )[0] - result_values.append(if_expr) - else: - # at least one of the two branches contains a value expression, which should be evaluated - # only if the corresponding true/false condition is satisfied - desc = sdfg.arrays[ - tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data - ] - var = unique_var_name() - if isinstance(desc, dace.data.Scalar): - sdfg.add_scalar(var, desc.dtype, transient=True) - else: - sdfg.add_array(var, desc.shape, desc.dtype, transient=True) - - # write result to transient data container and access it in the original state - for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]: - val_node = state.add_access(var) - if isinstance(expr, ValueExpr): - state.add_nedge( - expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc) - ) - else: - assert desc.shape == (1,) - state.add_edge( - state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"), - "_out", - val_node, - None, - dace.Memlet(var, "0"), - ) - result_values.append(ValueExpr(current_state.add_access(var), desc.dtype)) - - if tbr_state.is_empty() and fbr_state.is_empty(): - # if all branches are symbolic expressions, the true/false and join states can be removed - # as well as the conditional state transition - sdfg.remove_nodes_from([join_state, tbr_state, fbr_state]) - sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge()) - elif tbr_state.is_empty(): - # use direct edge from if-statement to join state for true branch - tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition - sdfg.remove_node(tbr_state) - elif fbr_state.is_empty(): - # use direct edge from if-statement to join state for false branch - fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition - sdfg.remove_node(fbr_state) - else: - # remove direct edge from if-statement to join state - sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0]) - # the if-statement condition is not used in current state - current_state.remove_node(ctx_stmt_node.value) - - return result_values - - -def builtin_list_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = list(itertools.chain(*transformer.visit(node_args))) - assert len(args) == 2 - # index node - if isinstance(args[0], SymbolExpr): - index_value = args[0].value - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) - result_node = transformer.context.state.add_access(result_name) - transformer.context.state.add_nedge( - args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value) - ) - return [ValueExpr(result_node, args[1].dtype)] - - else: - expr_args = [(arg, f"{arg.value.data}_v") for arg in args] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) - - -def builtin_cast( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = transformer.visit(node_args[0]) - internals = [f"{arg.value.data}_v" for arg in args] - target_type = node_args[1] - assert isinstance(target_type, itir.SymRef) - expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di - ) - - -def builtin_make_const_list( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = [transformer.visit(arg)[0] for arg in node_args] - assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) - args_dtype = [x.dtype for x in args] - assert len(set(args_dtype)) == 1 - dtype = args_dtype[0] - - var_name = unique_var_name() - transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) - var_node = transformer.context.state.add_access(var_name, debuginfo=di) - - for i, arg in enumerate(args): - if isinstance(arg, SymbolExpr): - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" - ), - "val", - var_node, - None, - dace.Memlet(data=var_name, subset=f"{i}"), - ) - else: - assert arg.value.desc(transformer.context.body).shape == (1,) - transformer.context.state.add_nedge( - arg.value, - var_node, - dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), - ) - - return [ValueExpr(var_node, dtype)] - - -def builtin_make_tuple( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - args = [transformer.visit(arg) for arg in node_args] - return args - - -def builtin_tuple_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - elements = transformer.visit(node_args[1]) - index = node_args[0] - if isinstance(index, itir.Literal): - return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants.") - - -_GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] -] = { - "can_deref": builtin_can_deref, - "cast_": builtin_cast, - "if_": builtin_if, - "list_get": builtin_list_get, - "make_const_list": builtin_make_const_list, - "make_tuple": builtin_make_tuple, - "neighbors": builtin_neighbors, - "tuple_get": builtin_tuple_get, -} - - -class GatherLambdaSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] - _parent_symbol_map: dict[str, TaskletExpr] - - def __init__(self, sdfg, state, parent_symbol_map): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - self._parent_symbol_map = parent_symbol_map - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the lambda expression.""" - return self._symbol_map - - def _add_symbol(self, param, arg): - if isinstance(arg, ValueExpr): - # create storage in lambda sdfg - self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbols - self._symbol_map[param] = ValueExpr( - self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype - ) - elif isinstance(arg, IteratorExpr): - # create storage in lambda sdfg - ndims = len(arg.dimensions) - shape, strides = new_array_symbols(param, ndims) - self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) - index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbols - field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) - for dim, index_arg in index_names.items() - } - self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - else: - assert isinstance(arg, SymbolExpr) - self._symbol_map[param] = arg - - def _add_tuple(self, param, args): - nodes = [] - # create storage in lambda sdfg for each tuple element - for arg in args: - var = unique_var_name() - self._sdfg.add_scalar(var, dtype=arg.dtype) - arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) - nodes.append(ValueExpr(arg_node, arg.dtype)) - # update table of lambda symbols - self._symbol_map[param] = tuple(nodes) - - def visit_SymRef(self, node: itir.SymRef): - name = str(node.id) - if name in self._parent_symbol_map and name not in self._symbol_map: - arg = self._parent_symbol_map[name] - self._add_symbol(name, arg) - - def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): - if args is not None: - if len(node.params) == len(args): - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) - else: - # implicitly make tuple - assert len(node.params) == 1 - self._add_tuple(str(node.params[0].id), args) - self.visit(node.expr) - - -class GatherOutputSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the output expression.""" - return self._symbol_map - - def __init__(self, sdfg, state): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - - def visit_SymRef(self, node: itir.SymRef): - param = str(node.id) - if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - self._symbol_map[param] = ValueExpr( - access_node, - dtype=itir_type_as_dace_type(node.type), # type: ignore[arg-type] # ensure by type inference - ) - - -@dataclasses.dataclass -class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider_type: common.OffsetProviderType - context: Context - use_field_canonical_representation: bool - - def get_sorted_field_dimensions(self, dims: Sequence[str]): - return sorted(dims) if self.use_field_canonical_representation else dims - - def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): - raise NotImplementedError() - - def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True - ) -> tuple[ - Context, - list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], - list[ValueExpr], - ]: - func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = ( - get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} - ) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # Create the SDFG for the lambda's body - lambda_sdfg = dace.SDFG(func_name) - lambda_sdfg.debuginfo = dace_utils.debug_info(node, default=self.context.body.debuginfo) - lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True) - - lambda_symbols_pass = GatherLambdaSymbolsPass( - lambda_sdfg, lambda_state, self.context.symbol_map - ) - lambda_symbols_pass.visit(node, args=args) - - # Add for input nodes for lambda symbols - inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] - for sym, input_node in lambda_symbols_pass.symbol_refs.items(): - params = [str(p.id) for p in node.params] - try: - param_index = params.index(sym) - except ValueError: - param_index = -1 - if param_index >= 0: - outer_node = args[param_index] - else: - # the symbol is not found among lambda arguments, then it is inherited from parent scope - outer_node = self.context.symbol_map[sym] - if isinstance(input_node, IteratorExpr): - assert isinstance(outer_node, IteratorExpr) - index_params = { - dim: index_node.data for dim, index_node in input_node.indices.items() - } - inputs.append(((sym, index_params), outer_node)) - elif isinstance(input_node, ValueExpr): - assert isinstance(outer_node, ValueExpr) - inputs.append((sym, outer_node)) - elif isinstance(input_node, tuple): - assert param_index >= 0 - for i, input_node_i in enumerate(input_node): - arg_i = args[param_index + i] - assert isinstance(arg_i, ValueExpr) - assert isinstance(input_node_i, ValueExpr) - inputs.append((input_node_i.value.data, arg_i)) - - # Add connectivities as arrays - for name in connectivity_names: - shape, strides = new_array_symbols(name, ndim=2) - dtype = self.context.body.arrays[name].dtype - lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - # Translate the lambda's body in its own context - lambda_context = Context( - lambda_sdfg, - lambda_state, - lambda_symbols_pass.symbol_refs, - reduce_identity=self.context.reduce_identity, - ) - lambda_taskgen = PythonTaskletCodegen( - self.offset_provider_type, - lambda_context, - self.use_field_canonical_representation, - ) - - results: list[ValueExpr] = [] - # We are flattening the returned list of value expressions because the multiple outputs of a lambda - # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. - node.expr.location = node.location - for expr in flatten_list(lambda_taskgen.visit(node.expr)): - if isinstance(expr, ValueExpr): - result_name = unique_var_name() - lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access( - result_name, debuginfo=lambda_sdfg.debuginfo - ) - lambda_state.add_nedge( - expr.value, result_access, dace.Memlet(data=result_access.data, subset="0") - ) - result = ValueExpr(value=result_access, dtype=expr.dtype) - else: - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet( - [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo - )[0] - lambda_sdfg.arrays[result.value.data].transient = False - results.append(result) - - # remove isolated access nodes for connectivity arrays not consumed by lambda - for sub_node in lambda_state.nodes(): - if isinstance(sub_node, dace.nodes.AccessNode): - if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: - lambda_state.remove_node(sub_node) - - return lambda_context, inputs, results - - def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - param = str(node.id) - value = self.context.symbol_map[param] - if isinstance(value, (ValueExpr, SymbolExpr)): - return [value] - return value - - def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: - return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))] - - def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: - node.fun.location = node.location - if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": - return self._visit_deref(node) - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "shift": - return self._visit_shift(node) - elif node.fun.fun.id == "reduce": - return self._visit_reduce(node) - - if isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - elif builtin_name in _GENERAL_BUILTIN_MAPPING: - return self._visit_general_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - return self._visit_call(node) - - def _visit_call(self, node: itir.FunCall): - args = self.visit(node.args) - args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] - args = list(itertools.chain(*args)) - node.fun.location = node.location - func_context, func_inputs, results = self.visit(node.fun, args=args) - - nsdfg_inputs = {} - for name, value in func_inputs: - if isinstance(value, ValueExpr): - nsdfg_inputs[name] = dace.Memlet.from_array( - value.value.data, self.context.body.arrays[value.value.data] - ) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - nsdfg_inputs[field] = dace.Memlet.from_array( - value.field.data, self.context.body.arrays[value.field.data] - ) - for dim, var in indices.items(): - store = value.indices[dim].data - nsdfg_inputs[var] = dace.Memlet.from_array( - store, self.context.body.arrays[store] - ) - - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) - - symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - - nsdfg_node = self.context.state.add_nested_sdfg( - func_context.body, - None, - inputs=set(nsdfg_inputs.keys()), - outputs=set(r.value.data for r in results), - symbol_mapping=symbol_mapping, - debuginfo=dace_utils.debug_info(node, default=func_context.body.debuginfo), - ) - - for name, value in func_inputs: - if isinstance(value, ValueExpr): - value_memlet = nsdfg_inputs[name] - self.context.state.add_edge(value.value, None, nsdfg_node, name, value_memlet) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - field_memlet = nsdfg_inputs[field] - self.context.state.add_edge(value.field, None, nsdfg_node, field, field_memlet) - for dim, var in indices.items(): - store = value.indices[dim] - idx_memlet = nsdfg_inputs[var] - self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) - self.context.state.add_edge(access, None, nsdfg_node, var, memlet) - - result_exprs = [] - for result in results: - name = unique_var_name() - self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) - result_exprs.append(ValueExpr(result_access, result.dtype)) - memlet = dace.Memlet.from_array(name, self.context.body.arrays[name]) - self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) - - return result_exprs - - def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # already a list of ValueExpr - return iterator - - sorted_dims = self.get_sorted_field_dimensions(iterator.dimensions) - if all([dim in iterator.indices for dim in iterator.dimensions]): - # The deref iterator has index values on all dimensions: the result will be a scalar - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet( - list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di - ) - - else: - dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] - assert len(dims_not_indexed) == 1 - offset = dims_not_indexed[0] - offset_provider_type = self.offset_provider_type[offset] - assert isinstance(offset_provider_type, common.NeighborConnectivityType) - neighbor_dim = offset_provider_type.codomain.value - - result_name = unique_var_name() - self.context.body.add_array( - result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True - ) - result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name, debuginfo=di) - - deref_connectors = ["_inp"] + [ - f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices - ] - deref_nodes = [iterator.field] + [ - iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices - ] - deref_memlets = [ - dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) - ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] - - # we create a mapped tasklet for array slicing - index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} - src_subset = ",".join( - [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] - ) - self.context.state.add_mapped_tasklet( - "deref", - map_ranges, - inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, - outputs={"_out": dace.Memlet.from_array(result_name, result_array)}, - code=f"_out[{index_name}] = _inp[{src_subset}]", - external_edges=True, - input_nodes={node.data: node for node in deref_nodes}, - output_nodes={result_name: result_node}, - debuginfo=di, - ) - return [ValueExpr(result_node, iterator.dtype)] - - def _split_shift_args( - self, args: list[itir.Expr] - ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: - pairs = [args[i : i + 2] for i in range(0, len(args), 2)] - assert len(pairs) >= 1 - assert all(len(pair) == 2 for pair in pairs) - return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None - - def _make_shift_for_rest(self, rest, iterator): - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), - args=[iterator], - location=iterator.location, - ) - - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - shift = node.fun - assert isinstance(shift, itir.FunCall) - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR pass is able to catch it - assert isinstance(iterator, list) and len(iterator) == 1 - assert isinstance(iterator[0], ValueExpr) - return iterator - - assert isinstance(tail[0], itir.OffsetLiteral) - offset_dim = tail[0].value - assert isinstance(offset_dim, str) - offset_node = self.visit(tail[1])[0] - assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - - if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): - offset_provider_type = cast( - common.NeighborConnectivityType, self.offset_provider_type[offset_dim] - ) # ensured by condition - connectivity = self.context.state.add_access( - dace_utils.connectivity_identifier(offset_dim), debuginfo=di - ) - - shifted_dim_tag = offset_provider_type.source_dim.value - target_dim_tag = offset_provider_type.codomain.value - args = [ - ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), - offset_node, - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" - else: - shifted_dim = self.offset_provider_type[offset_dim] - assert isinstance(shifted_dim, common.Dimension) - - shifted_dim_tag = shifted_dim.value - target_dim_tag = shifted_dim_tag - args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {internals[1]}" - - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim_tag] - shifted_index[target_dim_tag] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - offset = node.value - assert isinstance(offset, int) - offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var, debuginfo=di) - tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di - ) - self.context.state.add_edge( - tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0") - ) - return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] - - def _visit_reduce(self, node: itir.FunCall): - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - reduce_dtype = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - - if len(node.args) == 1: - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - reduce_identity = node.fun.args[1] - assert isinstance(reduce_identity, itir.Literal) - - # set reduction state - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args[0]) - - assert 1 <= len(args) <= 2 - reduce_input_node = args[0].value - - else: - assert isinstance(node.fun, itir.FunCall) - assert isinstance(node.fun.args[0], itir.Lambda) - fun_node = node.fun.args[0] - assert isinstance(fun_node.expr, itir.FunCall) - - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) - reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - - # set reduction state in visit context - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args) - - # clear context - self.context.reduce_identity = None - - # check that all neighbor expressions have the same shape - args_shape = [ - arg[0].value.desc(self.context.body).shape - for arg in args - if arg[0].value.desc(self.context.body).shape != (1,) - ] - assert len(set(args_shape)) == 1 - nreduce_shape = args_shape[0] - - input_args = [arg[0] for arg in args] - input_valid_args = [arg[1] for arg in args if len(arg) == 2] - - assert len(nreduce_shape) == 1 - nreduce_index = unique_name("_i") - nreduce_domain = {nreduce_index: f"0:{nreduce_shape[0]}"} - - reduce_input_name = unique_var_name() - self.context.body.add_array( - reduce_input_name, nreduce_shape, reduce_dtype, transient=True - ) - - lambda_node = itir.Lambda( - expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location - ) - lambda_context, inner_inputs, inner_outputs = self.visit( - lambda_node, args=input_args, use_neighbor_tables=False - ) - - input_mapping = { - param: ( - dace.Memlet(data=arg.value.data, subset="0") - if arg.value.desc(self.context.body).shape == (1,) - else dace.Memlet(data=arg.value.data, subset=nreduce_index) - ) - for (param, _), arg in zip(inner_inputs, input_args) - } - output_mapping = { - inner_outputs[0].value.data: dace.Memlet( - data=reduce_input_name, subset=nreduce_index - ) - } - symbol_mapping = map_nested_sdfg_symbols( - self.context.body, lambda_context.body, input_mapping - ) - - if input_valid_args: - """ - The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. - These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select - the result of field access or the identity value, respectively. - If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the select tasklet below is also skipped. - """ - input_args.append(input_valid_args[0]) - input_valid_node = input_valid_args[0].value - lambda_output_node = inner_outputs[0].value - # add input connector to nested sdfg - lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) - input_mapping["_valid_neighbor"] = dace.Memlet( - data=input_valid_node.data, subset=nreduce_index - ) - # add select tasklet before writing to output node - # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API - output_edge = lambda_context.state.in_edges(lambda_output_node)[0] - assert isinstance( - lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar - ) - select_tasklet = lambda_context.state.add_tasklet( - "neighbor_select", - {"_inp", "_valid"}, - {"_out"}, - f"_out = _inp if _valid else {reduce_identity}", - ) - lambda_context.state.add_edge( - output_edge.src, - None, - select_tasklet, - "_inp", - dace.Memlet(data=output_edge.src.data, subset="0"), - ) - lambda_context.state.add_edge( - lambda_context.state.add_access("_valid_neighbor"), - None, - select_tasklet, - "_valid", - dace.Memlet(data="_valid_neighbor", subset="0"), - ) - lambda_context.state.add_edge( - select_tasklet, - "_out", - lambda_output_node, - None, - dace.Memlet(data=lambda_output_node.data, subset="0"), - ) - lambda_context.state.remove_edge(output_edge) - - reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) - - nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( - self.context.state, - sdfg=lambda_context.body, - map_ranges=nreduce_domain, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - input_nodes={arg.value.data: arg.value for arg in input_args}, - output_nodes={reduce_input_name: reduce_input_node}, - debuginfo=di, - ) - - reduce_input_desc = reduce_input_node.desc(self.context.body) - - result_name = unique_var_name() - # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node - self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) - self.context.state.add_nedge( - reduce_input_node, - reduce_node, - dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), - ) - self.context.state.add_nedge( - reduce_node, result_access, dace.Memlet(data=result_name, subset="0") - ) - - return [ValueExpr(result_access, reduce_dtype)] - - def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = flatten_list(self.visit(node.args)) - expr_args = [ - (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) - ] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = fmt.format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return self.add_expr_tasklet( - expr_args, - expr, - type_, - "numeric", - dace_debuginfo=dace_utils.debug_info(node, default=self.context.body.debuginfo), - ) - - def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - expr_func = _GENERAL_BUILTIN_MAPPING[str(node.fun.id)] - return expr_func(self, node, node.args) - - def add_expr_tasklet( - self, - args: list[tuple[ValueExpr, str]], - expr: str, - result_type: Any, - name: str, - dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, - ) -> list[ValueExpr]: - di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - expr_tasklet = self.context.state.add_tasklet( - name=name, - inputs={internal for _, internal in args}, - outputs={"__result"}, - code=f"__result = {expr}", - debuginfo=di, - ) - - for arg, internal in args: - edges = self.context.state.in_edges(expr_tasklet) - used = False - for edge in edges: - if edge.dst_conn == internal: - used = True - break - if used: - continue - elif not isinstance(arg, SymbolExpr): - memlet = dace.Memlet.from_array( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - - memlet = dace.Memlet(data=result_access.data, subset="0") - self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) - - return [ValueExpr(result_access, result_type)] - - -def is_scan(node: itir.Node) -> bool: - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - -def closure_to_tasklet_sdfg( - node: itir.StencilClosure, - offset_provider_type: common.OffsetProviderType, - domain: dict[str, str], - inputs: Sequence[tuple[str, ts.TypeSpec]], - connectivities: Sequence[tuple[dace.ndarray, str]], - use_field_canonical_representation: bool, -) -> tuple[Context, Sequence[ValueExpr]]: - body = dace.SDFG("tasklet_toplevel") - body.debuginfo = dace_utils.debug_info(node) - state = body.add_state("tasklet_toplevel_entry", True) - symbol_map: dict[str, TaskletExpr] = {} - - idx_accesses = {} - for dim, idx in domain.items(): - name = f"{idx}_value" - body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet( - f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo - ) - access = state.add_access(name, debuginfo=body.debuginfo) - idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) - for name, ty in inputs: - if isinstance(ty, ts.FieldType): - ndim = len(ty.dims) - shape, strides = new_array_symbols(name, ndim) - dims = [dim.value for dim in ty.dims] - dtype = dace_utils.as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name, debuginfo=body.debuginfo) - indices = {dim: idx_accesses[dim] for dim in domain.keys()} - symbol_map[name] = IteratorExpr(field, indices, dtype, dims) - else: - assert isinstance(ty, ts.ScalarType) - dtype = dace_utils.as_dace_type(ty) - body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) - for arr, name in connectivities: - shape, strides = new_array_symbols(name, ndim=2) - body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) - - context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen( - offset_provider_type, context, use_field_canonical_representation - ) - - args = [itir.SymRef(id=name) for name, _ in inputs] - if is_scan(node.stencil): - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda( - expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location - ) - fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) - else: - fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) - - results = translator.visit(fun_node) - for r in results: - context.body.arrays[r.value.data].transient = False - - return context, results diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py deleted file mode 100644 index 72bb32f003..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ /dev/null @@ -1,149 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import itertools -from typing import Any - -import dace - -import gt4py.next.iterator.ir as itir -from gt4py import eve -from gt4py.next import common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils - - -def get_used_connectivities( - node: itir.Node, offset_provider_type: common.OffsetProviderType -) -> dict[str, common.NeighborConnectivityType]: - connectivities = dace_utils.filter_connectivity_types(offset_provider_type) - offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) - return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} - - -def map_nested_sdfg_symbols( - parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] -) -> dict[str, str]: - symbol_mapping: dict[str, str] = {} - for param, arg in array_mapping.items(): - arg_array = parent_sdfg.arrays[arg.data] - param_array = nested_sdfg.arrays[param] - if not isinstance(param_array, dace.data.Scalar): - assert len(arg.subset.size()) == len(param_array.shape) - for arg_shape, param_shape in zip(arg.subset.size(), param_array.shape): - if isinstance(param_shape, dace.symbol): - symbol_mapping[str(param_shape)] = str(arg_shape) - assert len(arg_array.strides) == len(param_array.strides) - for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): - if isinstance(param_stride, dace.symbol): - symbol_mapping[str(param_stride)] = str(arg_stride) - else: - assert arg.subset.num_elements() == 1 - for sym in nested_sdfg.free_symbols: - if str(sym) not in symbol_mapping: - symbol_mapping[str(sym)] = str(sym) - return symbol_mapping - - -def add_mapped_nested_sdfg( - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, -) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - - -def unique_name(prefix): - unique_id = getattr(unique_name, "_unique_id", 0) # static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" - - -def unique_var_name(): - return unique_name("_var") - - -def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - shape = [dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) for i in range(ndim)] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i in range(ndim) - ] - return shape, strides - - -def flatten_list(node_list: list[Any]) -> list[Any]: - return list( - itertools.chain.from_iterable( - [flatten_list(e) if isinstance(e, list) else [e] for e in node_list] - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py deleted file mode 100644 index 653ed4719d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ /dev/null @@ -1,150 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import functools -from typing import Callable, Optional, Sequence - -import dace -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import program_to_fencil -from gt4py.next.otf import languages, recipes, stages, step_types, workflow -from gt4py.next.otf.binding import interface -from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.type_system import type_specifications as ts - -from . import build_sdfg_from_itir - - -@dataclasses.dataclass(frozen=True) -class DaCeTranslator( - workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] - ], - step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], -): - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None - use_field_canonical_representation: bool = False - - def _language_settings(self) -> languages.LanguageSettings: - return languages.LanguageSettings( - formatter_key="", formatter_style="", file_extension="sdfg" - ) - - def generate_sdfg( - self, - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - column_axis: Optional[common.Dimension], - ) -> dace.SDFG: - on_gpu = ( - True - if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] - else False - ) - - return build_sdfg_from_itir( - program, - arg_types, - offset_provider_type=offset_provider_type, - auto_optimize=self.auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, - load_sdfg_from_file=False, - save_sdfg=False, - use_field_canonical_representation=self.use_field_canonical_representation, - ) - - def __call__( - self, inp: stages.CompilableProgram - ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: - """Generate DaCe SDFG file from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data - - if isinstance(program, itir.Program): - program = program_to_fencil.program_to_fencil(program) - - sdfg = self.generate_sdfg( - program, - inp.args.args, - common.offset_provider_to_type(inp.args.offset_provider), - inp.args.column_axis, - ) - - param_types = tuple( - interface.Parameter(param, arg) for param, arg in zip(sdfg.arg_names, inp.args.args) - ) - - module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( - stages.ProgramSource( - entry_point=interface.Function(program.id, param_types), - source_code=sdfg.to_json(), - library_deps=tuple(), - language=languages.SDFG, - language_settings=self._language_settings(), - implicit_domain=inp.data.implicit_domain, - ) - ) - return module - - -class DaCeTranslationStepFactory(factory.Factory): - class Meta: - model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - use_field_canonical_representation: bool = False - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - use_field_canonical_representation=o.use_field_canonical_representation, - ) - ) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 349d3e9f70..1593ab3ba6 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -11,11 +11,10 @@ import dataclasses import enum import importlib -from typing import Final, Optional, Protocol import pytest -from gt4py.next import allocators as next_allocators, backend as next_backend +from gt4py.next import allocators as next_allocators # Skip definitions @@ -67,10 +66,10 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): - DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" - DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu" - GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" - GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu" + DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" + DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" + DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -139,21 +138,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), -] -GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ +DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), @@ -189,10 +174,16 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST + + [ + (ALL, SKIP, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): Enable once the optimization pipeline is merged + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + (ALL, SKIP, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): Enable once the optimization pipeline is merged. + OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index f5646c71e4..08904c06f3 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,14 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from types import ModuleType -from typing import Optional - import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend, common +from gt4py.next import allocators as gtx_allocators, common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case @@ -34,24 +31,22 @@ try: import dace - from gt4py.next.program_processors.runners.dace import ( - itir_cpu as run_dace_cpu, - itir_gpu as run_dace_gpu, - ) except ImportError: dace: Optional[ModuleType] = None # type:ignore[no-redef] - run_dace_cpu: Optional[next_backend.Backend] = None - run_dace_gpu: Optional[next_backend.Backend] = None pytestmark = pytest.mark.requires_dace def test_sdfgConvertible_laplap(cartesian_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if cartesian_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - if cartesian_case.backend == run_dace_gpu: + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + + allocator, backend = unstructured_case.allocator, unstructured_case.backend + + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp else: import numpy as xp @@ -64,13 +59,13 @@ def test_sdfgConvertible_laplap(cartesian_case): def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( in_field, tmp_field ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( tmp_field, out_field ) @@ -94,13 +89,15 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift def test_sdfgConvertible_connectivities(unstructured_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if unstructured_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + allocator, backend = unstructured_case.allocator, unstructured_case.backend - if backend == run_dace_gpu: + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 794dd06709..1147f4bc3e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -66,11 +66,11 @@ def __gt_allocator__( marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU, + next_tests.definitions.OptionalProgramBackendId.DACE_GPU_NO_OPT, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], From a936761243319dbfa2c94e28222dc6962a96f14a Mon Sep 17 00:00:00 2001 From: SF-N Date: Wed, 4 Dec 2024 15:05:43 +0100 Subject: [PATCH 067/178] bug[next]: Fix astype for local fields (#1761) Fix astype by calling `_map` additionally and add corresponding tests Co-authored-by: Edoardo Paone --- src/gt4py/next/ffront/foast_to_gtir.py | 6 +-- .../dace_fieldview/gtir_python_codegen.py | 46 ++++++++++++------- .../ffront_tests/test_execution.py | 16 +++++++ .../ffront_tests/test_foast_to_gtir.py | 16 ++++++- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c2971f49a..3c65695aec 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - if isinstance(t[0], ts.FieldType): - return im.cast_as_fieldop(str(new_type))(expr) - else: - assert isinstance(t[0], ts.ScalarType) - return im.call("cast_")(expr, str(new_type)) + return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 6aee33c56e..4bdb602f5f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -118,29 +118,41 @@ class PythonCodegen(codegen.TemplatedGenerator): as in the case of field domain definitions, for sybolic array shape and map range. """ - SymRef = as_fmt("{id}") Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def visit_FunCall(self, node: gtir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) + def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: + if isinstance(node.fun, gtir.Lambda): + # update the mapping from lambda parameters to corresponding argument expressions + lambda_args_map = args_map | { + p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True) + } + return self.visit(node.fun.expr, args_map=lambda_args_map) + elif cpm.is_call_to(node, "deref"): + assert len(node.args) == 1 + if not isinstance(node.args[0], gtir.SymRef): + # shift expressions are not expected in this visitor context + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + return self.visit(node.args[0], args_map=args_map) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args) + args = self.visit(node.args, args_map=args_map) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: + symbol = str(node.id) + if symbol in args_map: + return self.visit(args_map[symbol], args_map=args_map) + return symbol + -get_source = PythonCodegen.apply -""" -Specialized visit method for symbolic expressions. +def get_source(node: gtir.Node) -> str: + """ + Specialized visit method for symbolic expressions. -Returns: - A string containing the Python code corresponding to a symbolic expression -""" + The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions. + + Returns: + A string containing the Python code corresponding to a symbolic expression + """ + return PythonCodegen.apply(node, args_map={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 0d994d1b22..4eed7f5cde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -438,6 +438,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_int_local_field(unstructured_case): + @gtx.field_operator + def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: + tmp = astype(a(E2V), int64) + return neighbor_sum(tmp, axis=E2VDim) + + e2v_table = unstructured_case.offset_provider["E2V"].ndarray + + cases.verify_with_default_data( + unstructured_case, + testee, + ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0), + comparison=lambda a, b: np.all(a == b), + ) + + @pytest.mark.uses_tuple_returns def test_astype_on_tuples(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 516890ea46..59a8dc961b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -283,9 +283,22 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.cast_as_fieldop("int32")("a") + assert lowered_inlined.expr == reference + + +def test_astype_local_field(): + def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + assert lowered.expr == reference @@ -295,10 +308,11 @@ def foo(a: float64): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.call("cast_")("a", "int32") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_tuple(): From 10adb2c7b3d26a31f9580218d2f8edc6fe67abbf Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 4 Dec 2024 15:09:45 +0100 Subject: [PATCH 068/178] test[next]: cleanup test markers (#1767) - Remove some test markers related to ITIR. - Fuse `uses_index_builtin` marker into `uses_index_fields`. --- pyproject.toml | 3 --- tests/next_tests/definitions.py | 5 ----- .../feature_tests/ffront_tests/test_execution.py | 4 ---- .../feature_tests/iterator_tests/test_program.py | 7 ++----- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e24094fa2..e859c9b4f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -240,17 +240,14 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', - 'starts_from_gtir_program: tests that require backend to start lowering from GTIR program', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', - 'uses_lift_expressions: tests that require backend support for lift expressions', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', - 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1593ab3ba6..80b8f4f39b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -85,8 +85,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # to avoid needing to mark all tests. ALL = "all" REQUIRES_ATLAS = "requires_atlas" -# TODO(havogt): Remove, skipped during refactoring to GTIR -STARTS_FROM_GTIR_PROGRAM = "starts_from_gtir_program" USES_APPLIED_SHIFTS = "uses_applied_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" @@ -94,10 +92,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" -USES_LIFT_EXPRESSIONS = "uses_lift_expressions" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" -USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" @@ -117,7 +113,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" -USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 4eed7f5cde..9de4449ac2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -291,7 +291,6 @@ def testee(a: tuple[cases.VField, cases.EField]) -> cases.VField: ) -@pytest.mark.uses_index_fields @pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): @gtx.field_operator @@ -602,7 +601,6 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator def testee(a: cases.VField) -> cases.VField: @@ -722,7 +720,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_scan -@pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) @@ -804,7 +801,6 @@ def testee( @pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index f6fd0a48d0..c79f8dbb6b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -42,7 +42,6 @@ def copy_program(inp, out, size): ) -@pytest.mark.starts_from_gtir_program def test_prog(program_processor): program_processor, validate = program_processor @@ -64,8 +63,7 @@ def index_program_simple(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin(program_processor): program_processor, validate = program_processor @@ -88,8 +86,7 @@ def index_program_shift(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin_shift(program_processor): program_processor, validate = program_processor From ea616597483aff2b29a6195a7a11071f907dedbb Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 4 Dec 2024 15:36:24 +0100 Subject: [PATCH 069/178] test[next]: Disable iterator tests on DaCe GTIR backend (#1768) There are 2 places where the `program_processor` fixture used in tests is configured: ``` tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py tests/next_tests/unit_tests/conftest.py ``` and there are four DaCe backends: ``` DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" ``` The `DACE_CPU` and `DACE_GPU` backends will be the default backends, that also apply the DaCe optimization pipeline. However, these backends are disabled for now because we are awaiting #1639. The `DACE_CPU_NO_OPT` and `DACE_GPU_NO_OPT` apply the lowering to SDFG but do not run the optimization pipeline. These backends are currently enabled in GT4Py tests. However, we observed failures in some iterator tests controlled by the `program_processor` fixture in `tests/next_tests/unit_tests/conftest.py`, once `DACE_CPU` is enabled. In this PR, we are disabling such tests: we will address these issues in a separate PR. --- tests/next_tests/unit_tests/conftest.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index f1269f1ed8..99bc44efa7 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,15 +60,6 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), - pytest.param( - (next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True), - marks=pytest.mark.requires_dace, - ), - # TODO(havogt): update tests to use proper allocation - # pytest.param( - # (next_tests.definitions.OptionalProgramBackendId.DACE_GPU, True), - # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - # ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) From 33c5ba33e07923ed8830a30b2907598d1a32867d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:50:01 +0100 Subject: [PATCH 070/178] feat[dace]: Updated DaCe Transformations (#1639) The [initial version](https://github.com/GridTools/gt4py/pull/1594) of the optimization pipeline only contained a rough draft. Currently this PR contains a copy of the map fusion transformations from DaCe that are currently under [review](https://github.com/spcl/dace/pull/1629). As soon as that PR is merged and DaCe was updated in GT4Py these files will be deleted. This PR collects some general improvements: - [x] More liberal `LoopBlocking` transformation (with tests). - [x] Incorporate `MapFusionParallel` - [x] Using of `can_be_applied_to()` as soon as DaCe is updated (`TrivialGPUMapElimination`, `SerialMapPromoter`). - [x] Looking at strides that the Lowering generates. (Partly done) However, it still uses MapFusion implementation that ships with GT4Py and not the one in DaCe. Note: Because of commit 60e4226 this PR must be merged after [PR1768](https://github.com/GridTools/gt4py/pull/1768). --- .../transformations/__init__.py | 42 +- .../{auto_opt.py => auto_optimize.py} | 271 ++--- .../transformations/gpu_utils.py | 666 ++++++++--- .../transformations/local_double_buffering.py | 393 +++++++ .../transformations/loop_blocking.py | 284 ++--- .../transformations/map_fusion_helper.py | 882 ++++++++------ .../transformations/map_fusion_parallel.py | 170 +++ .../transformations/map_fusion_serial.py | 1007 ++++++++++++++++ .../transformations/map_orderer.py | 144 ++- .../transformations/map_promoter.py | 42 +- .../transformations/map_serial_fusion.py | 483 -------- .../transformations/simplify.py | 1010 +++++++++++++++++ .../dace_fieldview/transformations/strides.py | 99 ++ .../dace_fieldview/transformations/util.py | 317 ++++-- tests/next_tests/definitions.py | 10 +- .../transformation_tests/conftest.py | 4 +- .../test_constant_substitution.py | 142 +++ .../test_create_local_double_buffering.py | 239 ++++ .../test_distributed_buffer_relocator.py | 84 ++ .../test_global_self_copy_elimination.py | 148 +++ .../transformation_tests/test_gpu_utils.py | 108 +- .../test_loop_blocking.py | 508 ++++++++- .../test_map_buffer_elimination.py | 264 +++++ .../transformation_tests/test_map_fusion.py | 124 +- .../transformation_tests/test_map_order.py | 100 ++ .../test_move_tasklet_into_map.py | 164 +++ .../test_serial_map_promoter.py | 4 +- .../dace_tests/transformation_tests/util.py | 6 +- 28 files changed, 6118 insertions(+), 1597 deletions(-) rename src/gt4py/next/program_processors/runners/dace_fieldview/transformations/{auto_opt.py => auto_optimize.py} (67%) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 8852dd6d2d..2232bcef01 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,32 +12,56 @@ that explains the general structure and requirements on the SDFGs. """ -from .auto_opt import ( +from .auto_optimize import gt_auto_optimize +from .gpu_utils import ( + GPUSetBlockSize, + gt_gpu_transform_non_standard_memlet, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) +from .local_double_buffering import gt_create_local_double_buffering +from .loop_blocking import LoopBlocking +from .map_fusion_parallel import MapFusionParallel +from .map_fusion_serial import MapFusionSerial +from .map_orderer import MapIterationOrder, gt_set_iteration_order +from .map_promoter import SerialMapPromoter +from .simplify import ( GT_SIMPLIFY_DEFAULT_SKIP_SET, - gt_auto_optimize, + GT4PyGlobalSelfCopyElimination, + GT4PyMapBufferElimination, + GT4PyMoveTaskletIntoMap, gt_inline_nested_sdfg, - gt_set_iteration_order, + gt_reduce_distributed_buffering, gt_simplify, + gt_substitute_compiletime_symbols, ) -from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize -from .loop_blocking import LoopBlocking -from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter -from .map_serial_fusion import SerialMapFusion +from .strides import gt_change_transient_strides +from .util import gt_find_constant_arguments, gt_make_transients_persistent __all__ = [ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", + "GT4PyGlobalSelfCopyElimination", + "GT4PyMoveTaskletIntoMap", + "GT4PyMapBufferElimination", "LoopBlocking", "MapIterationOrder", - "SerialMapFusion", + "MapFusionParallel", + "MapFusionSerial", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", + "gt_change_transient_strides", + "gt_create_local_double_buffering", "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_set_iteration_order", "gt_set_gpu_blocksize", "gt_simplify", + "gt_make_transients_persistent", + "gt_reduce_distributed_buffering", + "gt_find_constant_arguments", + "gt_substitute_compiletime_symbols", + "gt_gpu_transform_non_standard_memlet", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py similarity index 67% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index e070cdfe4e..bc1d21ca05 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -8,10 +8,10 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Final, Iterable, Optional, Sequence +from typing import Any, Optional, Sequence, Union import dace -from dace.transformation import dataflow as dace_dataflow, passes as dace_passes +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize from gt4py.next import common as gtx_common @@ -20,146 +20,12 @@ ) -GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} -"""Set of simplify passes `gt_simplify()` skips by default. - -The following passes are included: -- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a - symbol or vice versa and at a later point to invert this again. However, this - pass has some problems with this pattern so for the time being it is disabled. -- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. -""" - - -def gt_simplify( - sdfg: dace.SDFG, - validate: bool = True, - validate_all: bool = False, - skip: Optional[Iterable[str]] = None, -) -> Any: - """Performs simplifications on the SDFG in place. - - Instead of calling `sdfg.simplify()` directly, you should use this function, - as it is specially tuned for GridTool based SDFGs. - - This function runs the DaCe simplification pass, but the following passes are - replaced: - - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. - - Furthermore, by default, or if `None` is passed fro `skip` the passes listed in - `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. - - Args: - sdfg: The SDFG to optimize. - validate: Perform validation after the pass has run. - validate_all: Perform extensive validation. - skip: List of simplify passes that should not be applied, defaults - to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. - """ - # Ensure that `skip` is a `set` - skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) - - if "InlineSDFGs" not in skip: - gt_inline_nested_sdfg( - sdfg=sdfg, - multistate=True, - permissive=False, - validate=validate, - validate_all=validate_all, - ) - - return dace_passes.SimplifyPass( - validate=validate, - validate_all=validate_all, - verbose=False, - skip=(skip | {"InlineSDFGs"}), - ).apply_pass(sdfg, {}) - - -def gt_set_iteration_order( - sdfg: dace.SDFG, - leading_dim: gtx_common.Dimension, - validate: bool = True, - validate_all: bool = False, -) -> Any: - """Set the iteration order of the Maps correctly. - - Modifies the order of the Map parameters such that `leading_dim` - is the fastest varying one, the order of the other dimensions in - a Map is unspecific. `leading_dim` should be the dimensions were - the stride is one. - - Args: - sdfg: The SDFG to process. - leading_dim: The leading dimensions. - validate: Perform validation during the steps. - validate_all: Perform extensive validation. - """ - return sdfg.apply_transformations_once_everywhere( - gtx_transformations.MapIterationOrder( - leading_dim=leading_dim, - ), - validate=validate, - validate_all=validate_all, - ) - - -def gt_inline_nested_sdfg( - sdfg: dace.SDFG, - multistate: bool = True, - permissive: bool = False, - validate: bool = True, - validate_all: bool = False, -) -> dace.SDFG: - """Perform inlining of nested SDFG into their parent SDFG. - - The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. - However, before the inline transformation is run the function will run some - cleaning passes that allows inlining nested SDFGs. - As a side effect, the function will split stages into more states. - - Args: - sdfg: The SDFG that should be processed, will be modified in place and returned. - multistate: Allow inlining of multistate nested SDFG, defaults to `True`. - permissive: Be less strict on the accepted SDFGs. - validate: Perform validation after the transformation has finished. - validate_all: Performs extensive validation. - """ - first_iteration = True - i = 0 - while True: - print(f"ITERATION: {i}") - nb_preproccess = sdfg.apply_transformations_repeated( - [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], - validate=False, - validate_all=validate_all, - ) - if (nb_preproccess == 0) and (not first_iteration): - break - - # Create and configure the inline pass - inline_sdfg = dace_passes.InlineSDFGs() - inline_sdfg.progress = False - inline_sdfg.permissive = permissive - inline_sdfg.multistate = multistate - - # Apply the inline pass - nb_inlines = inline_sdfg.apply_pass(sdfg, {}) - - # Check result, if needed and test if we can stop - if validate_all or validate: - sdfg.validate() - if nb_inlines == 0: - break - first_iteration = False - - return sdfg - - def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool, - leading_dim: Optional[gtx_common.Dimension] = None, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, aggressive_fusion: bool = True, max_optimization_rounds_p2: int = 100, make_persistent: bool = True, @@ -169,6 +35,8 @@ def gt_auto_optimize( reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, + constant_symbols: Optional[dict[str, Any]] = None, + assume_pointwise: bool = True, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -184,6 +52,9 @@ def gt_auto_optimize( different aspects of the SDFG. The initial SDFG is assumed to have a very large number of rather simple Maps. + Note, because of how `gt_auto_optimizer()` works it is not save to call + it twice on the same SDFG. + 1. Some general simplification transformations, beyond classical simplify, are applied to the SDFG. 2. Tries to create larger kernels by fusing smaller ones, see @@ -223,20 +94,31 @@ def gt_auto_optimize( gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` for _all_ GPU Maps. + constant_symbols: Symbols listed in this `dict` will be replaced by the + respective value inside the SDFG. This might increase performance. + assume_pointwise: Assume that the SDFG has no risk for race condition in + global data access. See the `GT4PyMapBufferElimination` transformation for more. validate: Perform validation during the steps. validate_all: Perform extensive validation. + + Note: + For identifying symbols that can be treated as compile time constants + `gt_find_constant_arguments()` function can be used. + Todo: - - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily - overwriting it with `gt_simplify()`. + - Update the description. The Phases are nice, but they have lost their + link to reality a little bit. + - Improve the determination of the strides and iteration order of the + transients. + - Set padding of transients, i.e. alignment, the DaCe datadescriptor + can do that. + - Handle nested SDFGs better. - Specify arguments to set the size of GPU thread blocks depending on the dimensions. I.e. be able to use a different size for 1D than 2D Maps. - - Add a parallel version of Map fusion. - Implement some model to further guide to determine what we want to fuse. Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - - Create a custom array elimination pass that honors rule 1. - - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU @@ -249,20 +131,25 @@ def gt_auto_optimize( # to internal serial maps, such that they do not block fusion? # Phase 1: Initial Cleanup - gt_simplify( + gtx_transformations.gt_simplify( sdfg=sdfg, validate=validate, validate_all=validate_all, ) + gtx_transformations.gt_reduce_distributed_buffering(sdfg) + + if constant_symbols: + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg=sdfg, + repl=constant_symbols, + validate=validate, + validate_all=validate_all, + ) + gtx_transformations.gt_simplify(sdfg) + sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, - # TODO(phimuell): The transformation are interesting, but they have - # a bug as they assume that they are not working inside a map scope. - # Before we use them we have to fix them. - # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc - # dace_dataflow.MapReduceFusion, - # dace_dataflow.MapWCRFusion, ], validate=validate, validate_all=validate_all, @@ -278,28 +165,62 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. - # Currently this only applies fusion inside Maps. + # After we have created big kernels, we will perform some post cleanup. + gtx_transformations.gt_reduce_distributed_buffering(sdfg) sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_inner_maps=True, - ), + [ + gtx_transformations.GT4PyMoveTaskletIntoMap, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + ], validate=validate, validate_all=validate_all, ) - gt_simplify(sdfg) + + # TODO(phimuell): The `MapReduceFusion` transformation is interesting as + # it moves the initialization of the accumulator at the top, which allows + # further fusing of the accumulator loop. However the transformation has + # a bug, so we can not use it. Furthermore, I have looked at the assembly + # and the compiler is already doing that. + # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc + + # After we have created large kernels we run `dace_dataflow.MapReduceFusion`. + + # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. + # Currently this only applies fusion inside Maps. + gtx_transformations.gt_simplify(sdfg) + while True: + nb_applied = sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_inner_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_inner_maps=True, + only_if_common_ancestor=False, # TODO(phimuell): Should we? + ), + ], + validate=validate, + validate_all=validate_all, + ) + if not nb_applied: + break + gtx_transformations.gt_simplify(sdfg) # Phase 4: Iteration Space # This essentially ensures that the stride 1 dimensions are handled # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: - gt_set_iteration_order( + gtx_transformations.gt_set_iteration_order( sdfg=sdfg, leading_dim=leading_dim, validate=validate, validate_all=validate_all, ) + # We now ensure that point wise computations are properly double buffered. + # The main reason is to ensure that rule 3 of ADR18 is maintained. + gtx_transformations.gt_create_local_double_buffering(sdfg) + # Phase 5: Apply blocking if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( @@ -342,9 +263,23 @@ def gt_auto_optimize( dace_aoptimize.set_fast_implementations(sdfg, device) # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) + + # Now we modify the strides. + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + if make_persistent: - # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. - dace_aoptimize.make_transients_persistent(sdfg, device) + gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) + + if device == dace.DeviceType.GPU: + # NOTE: For unknown reasons the counterpart of the + # `gt_make_transients_persistent()` function in DaCe, resets the + # `wcr_nonatomic` property of every memlet, i.e. makes it atomic. + # However, it does this only for edges on the top level and on GPU. + # For compatibility with DaCe (and until we found out why) the GT4Py + # auto optimizer will emulate this behaviour. + for state in sdfg.states(): + for edge in state.edges(): + edge.data.wcr_nonatomic = False return sdfg @@ -395,9 +330,17 @@ def gt_auto_fuse_top_level_maps( # TODO(phimuell): Add parallel fusion transformation. Should it run after # or with the serial one? sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + # This will lead to the creation of big probably unrelated maps. + # However, it might be good. + only_if_common_ancestor=False, + ), + ], validate=validate, validate_all=validate_all, ) @@ -437,7 +380,7 @@ def gt_auto_fuse_top_level_maps( # The SDFG was modified by the transformations above. The SDFG was # modified. Call Simplify and try again to further optimize. - gt_simplify(sdfg, validate=validate, validate_all=validate_all) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) else: raise RuntimeWarning("Optimization of the SDFG did not converge.") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 16c9600a3a..2cd3020180 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -11,10 +11,15 @@ from __future__ import annotations import copy -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Final, Optional, Sequence, Union import dace -from dace import properties as dace_properties, transformation as dace_transformation +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + transformation as dace_transformation, +) +from dace.codegen.targets import cpp as dace_cpp from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -51,7 +56,9 @@ def gt_gpu_transformation( will avoid the data copy from host to GPU memory. gpu_block_size: The size of a thread block on the GPU. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + Will only take effect if `gpu_block_size` is specified. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + Will only take effect if `gpu_block_size` is specified. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -82,39 +89,197 @@ def gt_gpu_transformation( validate_all=validate_all, simplify=False, ) + # The documentation recommends to run simplify afterwards gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # A Tasklet, outside of a Map, that writes into an array on GPU can not work - # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet - # would write into a Scalar that then goes into a GPU Map, nothing would - # happen. So we might end up with lot of these trivial Maps, that results - # in a single kernel launch. To prevent this we will try to fuse them. - # NOTE: The current implementation has a bug, because promotion and fusion - # are two different steps. Because of this the function will implicitly - # fuse everything together it can find. - # TODO(phimuell): Fix the issue described above. + # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on + # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # we might end up with lots of these trivial Maps, each requiring a separate + # kernel launch. To prevent this we will combine these trivial maps, if + # possible, with their downstream maps. sdfg.apply_transformations_once_everywhere( - TrivialGPUMapPromoter(), + TrivialGPUMapElimination(), validate=False, validate_all=False, ) - sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), - validate=validate, - validate_all=validate_all, - ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # TODO(phimuell): Fixing the stride problem. + sdfg = gt_gpu_transform_non_standard_memlet( + sdfg=sdfg, + map_postprocess=True, + validate=validate, + validate_all=validate_all, + ) # Set the GPU block size if it is known. if gpu_block_size is not None: gt_set_gpu_blocksize( sdfg=sdfg, - gpu_block_size=gpu_block_size, - gpu_launch_bounds=gpu_launch_bounds, - gpu_launch_factor=gpu_launch_factor, + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + + if validate_all or validate: + sdfg.validate() + + return sdfg + + +def gt_gpu_transform_non_standard_memlet( + sdfg: dace.SDFG, + map_postprocess: bool, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform some non standard Melets to Maps. + + The GPU code generator is not able to handle certain sets of Memlets. To + handle them, the code generator transforms them into copy Maps. The main + issue is that this transformation happens after the auto optimizer, thus + the copy-Maps will most likely have the wrong iteration order. + + This function allows to perform the preprocessing step before the actual + code generation. The function will perform the expansion. If + `map_postprocess` is `True` then the function will also apply MapFusion, + to these newly created copy-Maps and set their iteration order correctly. + + A user should not call this function directly, instead this function is + called by the `gt_gpu_transformation()` function. + + Args: + sdfg: The SDFG that we process. + map_postprocess: Enable post processing of the maps that are created. + See the Note section below. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + + Note: + - Currently the function applies some crude heuristic to determine the + correct loop order. + - This function should be called after `gt_set_iteration_order()` has run. + """ + new_maps: set[dace_nodes.MapEntry] = set() + + # This code is is copied from DaCe's code generator. + for e, state in list(sdfg.all_edges_recursive()): + nsdfg = state.parent + if ( + isinstance(e.src, dace_nodes.AccessNode) + and isinstance(e.dst, dace_nodes.AccessNode) + and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + ): + a: dace_nodes.AccessNode = e.src + b: dace_nodes.AccessNode = e.dst + + copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( + None, nsdfg, state, e, a, b + ) + dims = len(copy_shape) + if dims == 1: + continue + elif dims == 2: + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + else: + continue + elif dims > 2: + if not (src_strides[-1] != 1 or dst_strides[-1] != 1): + continue + + # For identifying the new map, we first store all neighbors of `a`. + old_neighbors_of_a: list[dace_nodes.AccessNode] = [ + edge.dst for edge in state.out_edges(a) + ] + + # Turn unsupported copy to a map + try: + dace_transformation.dataflow.CopyToMap.apply_to( + nsdfg, save=False, annotate=False, a=a, b=b + ) + except ValueError: # If transformation doesn't match, continue normally + continue + + # We find the new map by comparing the new neighborhood of `a` with the old one. + new_nodes: set[dace_nodes.MapEntry] = { + edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a + } + assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) + assert len(new_nodes) == 1 + new_maps.update(new_nodes) + + # If there are no Memlets that are translated to copy-Maps, then we have nothing to do. + if len(new_maps) == 0: + return sdfg + + # This function allows to restrict any fusion operation to the maps + # that we have just created. + def restrict_fusion_to_newly_created_maps( + self: gtx_transformations.map_fusion_helper.MapFusionHelper, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) + + # Using the callback to restrict the fusing + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + # Now we have to find the maps that were not fused. We rely here on the fact + # that at least one of the map that is involved in fusing still exists. + maps_to_modify: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + for state in nsdfg.states(): + for map_entry in state.nodes(): + if not isinstance(map_entry, dace_nodes.MapEntry): + continue + if map_entry in new_maps: + maps_to_modify.add(map_entry) + assert 0 < len(maps_to_modify) <= len(new_maps) + + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is not an unique stride dimension. + # - The newly created maps have names that does not reflect GT4Py dimensions, + # thus we can not use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just swapping the order of the map parameters. + # TODO(phimuell): Do it properly. + for me_to_modify in maps_to_modify: + map_to_modify: dace_nodes.Map = me_to_modify.map + map_to_modify.params = list(reversed(map_to_modify.params)) + map_to_modify.range = dace.subsets.Range( + (r1, r2, r3, t) + for (r1, r2, r3), t in zip( + reversed(map_to_modify.range.ranges), reversed(map_to_modify.range.tile_sizes) + ) ) return sdfg @@ -122,131 +287,214 @@ def gt_gpu_transformation( def gt_set_gpu_blocksize( sdfg: dace.SDFG, - gpu_block_size: Optional[Sequence[int | str] | str], - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, + block_size: Optional[Sequence[int | str] | str], + launch_bounds: Optional[int | str] = None, + launch_factor: Optional[int] = None, + **kwargs: Any, ) -> Any: """Set the block size related properties of _all_ Maps. - See `GPUSetBlockSize` for more information. + It supports the same arguments as `GPUSetBlockSize`, however it also has + versions without `_Xd`, these are used as default for the other maps. + If a version with `_Xd` is specified then it takes precedence. Args: sdfg: The SDFG to process. - gpu_block_size: The size of a thread block on the GPU. + block_size: The size of a thread block on the GPU. launch_bounds: The value for the launch bound that should be used. launch_factor: If no `launch_bounds` was given use the number of threads in a block multiplied by this number. """ - xform = GPUSetBlockSize( - block_size=gpu_block_size, - launch_bounds=gpu_launch_bounds, - launch_factor=gpu_launch_factor, - ) - return sdfg.apply_transformations_once_everywhere([xform]) - - -def _gpu_block_parser( - self: GPUSetBlockSize, - val: Any, -) -> None: - """Used by the setter of `GPUSetBlockSize.block_size`.""" - org_val = val - if isinstance(val, (tuple | list)): - pass - elif isinstance(val, str): - val = tuple(x.strip() for x in val.split(",")) - elif isinstance(val, int): - val = (val,) - else: - raise TypeError( - f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." - ) - if 0 < len(val) <= 3: - val = [*val, *([1] * (3 - len(val)))] - else: - raise ValueError(f"Can not parse block size '{org_val}': wrong length") - try: - val = [int(x) for x in val] - except ValueError: - raise TypeError( - f"Currently only block sizes convertible to int are supported, you passed '{val}'." - ) from None - self._block_size = val - + for dim in [1, 2, 3]: + for arg, val in { + "block_size": block_size, + "launch_bounds": launch_bounds, + "launch_factor": launch_factor, + }.items(): + if f"{arg}_{dim}d" not in kwargs: + kwargs[f"{arg}_{dim}d"] = val + return sdfg.apply_transformations_once_everywhere(GPUSetBlockSize(**kwargs)) + + +def _make_gpu_block_parser_for( + dim: int, +) -> Callable[["GPUSetBlockSize", Any], None]: + """Generates a parser for GPU blocks for dimension `dim`. + + The returned function can be used as parser for the `GPUSetBlockSize.block_size_*d` + properties. + """ -def _gpu_block_getter( - self: "GPUSetBlockSize", -) -> tuple[int, int, int]: - """Used as getter in the `GPUSetBlockSize.block_size` property.""" - assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 - assert all(isinstance(x, int) for x in self._block_size) - return tuple(self._block_size) + def _gpu_block_parser( + self: GPUSetBlockSize, + val: Any, + ) -> None: + """Used by the setter of `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, (tuple | list)): + pass + elif isinstance(val, str): + val = tuple(x.strip() for x in val.split(",")) + elif isinstance(val, int): + val = (val,) + else: + raise TypeError( + f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." + ) + if len(val) < dim: + raise ValueError( + f"The passed block size only covers {len(val)} dimensions, but dimension was {dim}." + ) + if 0 < len(val) <= 3: + val = [*val, *([1] * (3 - len(val)))] + else: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + + # Remove over specification. + for i in range(dim, 3): + val[i] = 1 + setattr(self, f"_block_size_{dim}d", tuple(val)) + + return _gpu_block_parser + + +def _make_gpu_block_getter_for( + dim: int, +) -> Callable[["GPUSetBlockSize"], tuple[int, int, int]]: + """Makes the getter for the block size of dimension `dim`.""" + + def _gpu_block_getter( + self: "GPUSetBlockSize", + ) -> tuple[int, int, int]: + """Used as getter in the `GPUSetBlockSize.block_size` property.""" + return getattr(self, f"_block_size_{dim}d") + + return _gpu_block_getter + + +def _gpu_launch_bound_parser( + block_size: tuple[int, int, int], + launch_bounds: int | str | None, + launch_factor: int | None = None, +) -> str | None: + """Used by the `GPUSetBlockSize.__init__()` method to parse the launch bounds.""" + if launch_bounds is None and launch_factor is None: + return None + elif launch_bounds is None and launch_factor is not None: + return str(int(launch_factor) * block_size[0] * block_size[1] * block_size[2]) + elif launch_bounds is not None and launch_factor is None: + assert isinstance(launch_bounds, (str, int)) + return str(launch_bounds) + else: + raise ValueError("Specified both `launch_bounds` and `launch_factor`.") @dace_properties.make_properties class GPUSetBlockSize(dace_transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. - The transformation will apply to all Maps that have a GPU schedule, regardless - of their dimensionality. + The `block_size` is either a sequence, of up to three integers or a string + of up to three numbers, separated by comma (`,`). The first number is the size + of the block in `x` direction, the second for the `y` direction and the third + for the `z` direction. Missing values will be filled with `1`. - The `gpu_block_size` is either a sequence, of up to three integers or a string - of up to three numbers, separated by comma (`,`). - The first number is the size of the block in `x` direction, the second for the - `y` direction and the third for the `z` direction. Missing values will be filled - with `1`. + A different value for the GPU block size and launch bound can be specified for + maps of dimension 1, 2 or 3 (all maps with higher dimensions are considered + three dimensional). If no value is specified then the block size `(32, 1, 1)` + will be used an no launch bound will be be emitted. Args: - block_size: The size of a thread block on the GPU. - launch_bounds: The value for the launch bound that should be used. - launch_factor: If no `launch_bounds` was given use the number of threads - in a block multiplied by this number. + block_size_Xd: The size of a thread block on the GPU for `X` dimensional maps. + launch_bounds_Xd: The value for the launch bound that should be used for `X` + dimensional maps. + launch_factor_Xd: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number, for maps of dimension `X`. - Todo: - Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. + Note: + - You should use the `gt_set_gpu_blocksize()` function. + - "Over specification" is ignored, i.e. if `(32, 3, 1)` is passed as block + size for 1 dimensional maps, then it is changed to `(32, 1, 1)`. """ - block_size = dace_properties.Property( - dtype=None, - allow_none=False, - default=(32, 1, 1), - setter=_gpu_block_parser, - getter=_gpu_block_getter, - desc="Size of the block size a GPU Map should have.", - ) + _block_size_default: Final[tuple[int, int, int]] = (32, 1, 1) - launch_bounds = dace_properties.Property( + block_size_1d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(1), + getter=_make_gpu_block_getter_for(1), + desc="Block size for 1 dimensional GPU maps.", + ) + launch_bounds_1d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 1 dimensional map.", + ) + block_size_2d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(2), + getter=_make_gpu_block_getter_for(2), + desc="Block size for 2 dimensional GPU maps.", + ) + launch_bounds_2d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 2 dimensional map.", + ) + block_size_3d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(3), + getter=_make_gpu_block_getter_for(3), + desc="Block size for 3 dimensional GPU maps.", + ) + launch_bounds_3d = dace_properties.Property( dtype=str, allow_none=True, default=None, - desc="Set the launch bound property of the map.", + desc="Set the launch bound property for 3 dimensional map.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + # Pattern matching + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - block_size: Sequence[int | str] | str | None = None, - launch_bounds: int | str | None = None, - launch_factor: int | None = None, + block_size_1d: Sequence[int | str] | str | None = None, + block_size_2d: Sequence[int | str] | str | None = None, + block_size_3d: Sequence[int | str] | str | None = None, + launch_bounds_1d: int | str | None = None, + launch_bounds_2d: int | str | None = None, + launch_bounds_3d: int | str | None = None, + launch_factor_1d: int | None = None, + launch_factor_2d: int | None = None, + launch_factor_3d: int | None = None, ) -> None: super().__init__() - if block_size is not None: - self.block_size = block_size - - if launch_factor is not None: - assert launch_bounds is None - self.launch_bounds = str( - int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] - ) - elif launch_bounds is None: - self.launch_bounds = None - elif isinstance(launch_bounds, (str, int)): - self.launch_bounds = str(launch_bounds) - else: - raise TypeError( - f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." - ) + if block_size_1d is not None: + self.block_size_1d = block_size_1d + if block_size_2d is not None: + self.block_size_2d = block_size_2d + if block_size_3d is not None: + self.block_size_3d = block_size_3d + self.launch_bounds_1d = _gpu_launch_bound_parser( + self.block_size_1d, launch_bounds_1d, launch_factor_1d + ) + self.launch_bounds_2d = _gpu_launch_bound_parser( + self.block_size_2d, launch_bounds_2d, launch_factor_2d + ) + self.launch_bounds_3d = _gpu_launch_bound_parser( + self.block_size_3d, launch_bounds_3d, launch_factor_3d + ) @classmethod def expressions(cls) -> Any: @@ -266,7 +514,6 @@ def can_be_applied( - If the map is at global scope. - If if the schedule of the map is correct. """ - scope = graph.scope_dict() if scope[self.map_entry] is not None: return False @@ -282,35 +529,69 @@ def apply( sdfg: dace.SDFG, ) -> None: """Modify the map as requested.""" - self.map_entry.map.gpu_block_size = self.block_size - if self.launch_bounds is not None: # Note empty string has a meaning in DaCe - self.map_entry.map.gpu_launch_bounds = self.launch_bounds + gpu_map: dace_nodes.Map = self.map_entry.map + if len(gpu_map.params) == 1: + block_size = self.block_size_1d + launch_bounds = self.launch_bounds_1d + elif len(gpu_map.params) == 2: + block_size = self.block_size_2d + launch_bounds = self.launch_bounds_2d + else: + block_size = self.block_size_3d + launch_bounds = self.launch_bounds_3d + gpu_map.gpu_block_size = block_size + if launch_bounds is not None: # Note: empty string has a meaning in DaCe + gpu_map.gpu_launch_bounds = launch_bounds @dace_properties.make_properties -class TrivialGPUMapPromoter(dace_transformation.SingleStateTransformation): - """Serial Map promoter for empty GPU maps. +class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): + """Eliminate certain kind of trivial GPU maps. - In CPU mode a Tasklet can be outside of a map, however, this is not - possible in GPU mode. For this reason DaCe wraps such Tasklets in a - trivial Map. - This transformation will look for such Maps and promote them, such - that they can be fused with downstream maps. + A tasklet outside of map can not write to GPU memory, this can only be done + from within a map (a scalar is possible). For that reason DaCe's GPU + transformation wraps such tasklets in trivial maps. + Under certain condition the transformation will fuse the trivial tasklet with + a downstream (serial) map. + + Args: + do_not_fuse: If `True` then the maps are not fused together. + only_gpu_maps: Only apply to GPU maps; `True` by default. Note: - This transformation should not be run on its own, instead it is run within the context of `gt_gpu_transformation()`. - This transformation must be run after the GPU Transformation. - - Currently the transformation does not do the fusion on its own. - Instead map fusion must be run afterwards. - - The transformation assumes that the upper Map is a trivial Tasklet. - Which should be the majority of all cases. """ + only_gpu_maps = dace_properties.Property( + dtype=bool, + default=True, + desc="Only promote maps that are GPU maps (debug option).", + ) + do_not_fuse = dace_properties.Property( + dtype=bool, + default=False, + desc="Only perform the promotion, do not fuse.", + ) + # Pattern Matching - trivial_map_exit = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - second_map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + trivial_map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + do_not_fuse: Optional[bool] = None, + only_gpu_maps: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_gpu_maps is not None: + self.only_gpu_maps = only_gpu_maps + if do_not_fuse is not None: + self.do_not_fuse = do_not_fuse @classmethod def expressions(cls) -> Any: @@ -332,63 +613,118 @@ def can_be_applied( The tests includes: - Schedule of the maps. - If the map is trivial. - - If the trivial map was not used to define a symbol. - - Intermediate access node can only have in and out degree of 1. - - The trivial map exit can only have one output. + - Tests if the maps can be fused. """ trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = trivial_map_exit.map trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map - access_node: dace_nodes.AccessNode = self.access_node # The kind of maps we are interested only have one parameter. if len(trivial_map.params) != 1: return False - - # Check if it is a GPU map - for map_to_check in [trivial_map, second_map]: - if map_to_check.schedule not in [ - dace.dtypes.ScheduleType.GPU_Device, - dace.dtypes.ScheduleType.GPU_Default, - ]: - return False - - # Check if the map is trivial. for rng in trivial_map.range.ranges: if rng[0] != rng[1]: return False - # Now we have to ensure that the symbol is not used inside the scope of the - # map, if it is, then the symbol is just there to define a symbol. - scope_view = graph.scope_subgraph( - trivial_map_entry, - include_entry=False, - include_exit=False, - ) - if any(map_param in scope_view.free_symbols for map_param in trivial_map.params): - return False + # If we do not not fuse, then the second map can not be trivial. + # If we would not prevent that case then we would match these two + # maps again and again. + if self.do_not_fuse and len(second_map.params) <= 1: + for rng in second_map.range.ranges: + if rng[0] == rng[1]: + return False + + # We now check that the Memlets do not depend on the map parameter. + # This is important for the `can_be_applied_to()` check we do below + # because we can avoid calling the replace function. + scope = graph.scope_subgraph(trivial_map_entry) + trivial_map_param: str = trivial_map.params[0] + for edge in scope.edges(): + if trivial_map_param in edge.data.free_symbols: + return False - # ensuring that the trivial map exit and the intermediate node have degree - # one is a cheap way to ensure that the map can be merged into the - # second map. - if graph.in_degree(access_node) != 1: - return False - if graph.out_degree(access_node) != 1: - return False - if graph.out_degree(trivial_map_exit) != 1: - return False + # Check if only GPU maps are involved (this is more a testing debug feature). + if self.only_gpu_maps: + for map_to_check in [trivial_map, second_map]: + if map_to_check.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Now we check if the two maps can be fused together. For that we have to + # do a temporary promotion, it is important that we do not perform the + # renaming. If the old symbol is still used, it is used inside a tasklet + # so it would show up (temporarily) as free symbol. + org_trivial_map_params = copy.deepcopy(trivial_map.params) + org_trivial_map_range = copy.deepcopy(trivial_map.range) + try: + self._promote_map(graph, replace_trivail_map_parameter=False) + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=self.access_node, + map_entry_2=self.second_map_entry, + ): + return False + finally: + trivial_map.params = org_trivial_map_params + trivial_map.range = org_trivial_map_range return True def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the Map Promoting. - The function essentially copies the parameters and the ranges from the - bottom map to the top one. + The function will first perform the promotion of the trivial map and then + perform the merging of the two maps in one go. """ + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + access_node: dace_nodes.AccessNode = self.access_node + + # Promote the maps. + self._promote_map(graph) + + # Perform the fusing if requested. + if not self.do_not_fuse: + gtx_transformations.MapFusionSerial.apply_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + verify=True, + ) + + def _promote_map( + self, + state: dace.SDFGState, + replace_trivail_map_parameter: bool = True, + ) -> None: + """Performs the map promoting. + + Essentially this function will copy the parameters and the range from + the non trivial map (`self.second_map_entry.map`) to the trivial map + (`self.trivial_map_exit.map`). + + If `replace_trivail_map_parameter` is `True` (the default value), then the + function will also remove the trivial map parameter with its value. + """ + assert isinstance(self.trivial_map_exit, dace_nodes.MapExit) + assert isinstance(self.second_map_entry, dace_nodes.MapEntry) + assert isinstance(self.access_node, dace_nodes.AccessNode) + + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = self.trivial_map_exit.map + trivial_map_entry: dace_nodes.MapEntry = state.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map + # If requested then replace the map variable with its value. + if replace_trivail_map_parameter: + scope = state.scope_subgraph(trivial_map_entry) + scope.replace(trivial_map.params[0], trivial_map.range[0][0]) + + # Now copy parameter and the ranges from the second to the trivial map. trivial_map.params = copy.deepcopy(second_map.params) trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py new file mode 100644 index 0000000000..52f1de3d0c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py @@ -0,0 +1,393 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +import copy + +import dace +from dace import ( + data as dace_data, + dtypes as dace_dtypes, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_create_local_double_buffering( + sdfg: dace.SDFG, +) -> int: + """Modifies the SDFG such that point wise data dependencies are stable. + + Rule 3 of the ADR18, guarantees that if data is input and output to a map, + then it must be a non transient array and it must only have point wise + dependency. This means that every index that is read is also written by + the same thread and no other thread reads or writes to the same location. + However, because the dataflow inside a map is partially asynchronous + it might happen if something is read multiple times, i.e. Tasklets, + the data might already be overwritten. + This function will scan the SDFG for potential cases and insert an + access node to cache this read. This is essentially a double buffer, but + it is not needed that the whole data is stored, but only the working set + of a single thread. + """ + + processed_maps = 0 + for nsdfg in sdfg.all_sdfgs_recursive(): + processed_maps += _create_local_double_buffering_non_recursive(nsdfg) + return processed_maps + + +def _create_local_double_buffering_non_recursive( + sdfg: dace.SDFG, +) -> int: + """Implementation of the point wise transformation. + + This function does not handle nested SDFGs. + """ + # First we call `EdgeConsolidation`, because of that we know that + # every incoming edge of a `MapEntry` refers to distinct data. + # We do this to simplify our implementation. + edge_consolidation = dace_transformation.passes.ConsolidateEdges() + edge_consolidation.apply_pass(sdfg, None) + + processed_maps = 0 + for state in sdfg.states(): + scope_dict = state.scope_dict() + for node in state.nodes(): + if not isinstance(node, dace_nodes.MapEntry): + continue + if scope_dict[node] is not None: + continue + inout_nodes = _check_if_map_must_be_handled( + map_entry=node, + state=state, + sdfg=sdfg, + ) + if inout_nodes is not None: + processed_maps += _add_local_double_buffering_to( + map_entry=node, + inout_nodes=inout_nodes, + state=state, + sdfg=sdfg, + ) + return processed_maps + + +def _add_local_double_buffering_to( + inout_nodes: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> int: + """Adds the double buffering to `map_entry` for `inout_nodes`. + + The function assumes that there is only in incoming edge per data + descriptor at the map entry. If the data is needed multiple times, + then the distribution must be done inside the map. + + The function will now channel all reads to the data descriptor + through an access node, this ensures that the read happens + before the write. + """ + processed_maps = 0 + for inout_node in inout_nodes.values(): + _add_local_double_buffering_to_single_data( + map_entry=map_entry, + inout_node=inout_node, + state=state, + sdfg=sdfg, + ) + processed_maps += 1 + return processed_maps + + +def _add_local_double_buffering_to_single_data( + inout_node: tuple[dace_nodes.AccessNode, dace_nodes.AccessNode], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Adds the local double buffering for a single data.""" + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + input_node, output_node = inout_node + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert len(input_edges) == 1 + assert len(output_edges) == 1 + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + inner_write_edges = _get_inner_edges(output_edges[0], map_exit, state, True) + + # For now we assume that all read the same, which is checked below. + new_double_inner_buff_shape_raw = dace_symbolic.overapproximate( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state).size() + ) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_double_inner_buff_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_double_inner_buff_shape_raw, input_node.desc(sdfg).shape) + ): + if full_dim_size == 1: # Must be kept! + new_double_inner_buff_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_double_inner_buff_shape.append(proposed_dim_size) + + new_double_inner_buff_name: str = f"__inner_double_buffer_for_{input_node.data}" + # Now generate the intermediate data container. + if len(new_double_inner_buff_shape) == 0: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_scalar( + new_double_inner_buff_name, + dtype=input_node.desc(sdfg).dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + else: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_transient( + new_double_inner_buff_name, + shape=new_double_inner_buff_shape, + dtype=input_node.desc(sdfg).dtype, + find_new_name=True, + storage=dace_dtypes.StorageType.Register, + ) + new_double_inner_buff_node = state.add_access(new_double_inner_buff_name) + + # Now reroute the data flow through the new access node. + for old_inner_read_edge in inner_read_edges: + # To do handle the case the memlet is "fancy" + state.add_edge( + new_double_inner_buff_node, + None, + old_inner_read_edge.dst, + old_inner_read_edge.dst_conn, + dace.Memlet( + data=new_double_inner_buff_name, + subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + other_subset=copy.deepcopy( + old_inner_read_edge.data.get_dst_subset(old_inner_read_edge, state) + ), + ), + ) + state.remove_edge(old_inner_read_edge) + + # Now create a connection between the map entry and the intermediate node. + state.add_edge( + map_entry, + inner_read_edges[0].src_conn, + new_double_inner_buff_node, + None, + dace.Memlet( + data=input_node.data, + subset=copy.deepcopy( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state) + ), + other_subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + ), + ) + + # To really ensure that a read happens before a write, we have to sequence + # the read first. We do this by connecting the double buffer node with + # empty Memlets to the last row of nodes that writes to the global buffer. + # This is needed to handle the case that some other data path performs the + # write. + # TODO(phimuell): Add a test that only performs this when it is really needed. + for inner_write_edge in inner_write_edges: + state.add_nedge( + new_double_inner_buff_node, + inner_write_edge.src, + dace.Memlet(), + ) + + +def _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node: dace_nodes.AccessNode, + sdfg: dace.SDFG, + known_nodes: dict[str, dace_nodes.AccessNode], +) -> bool: + """Internal function used by `_check_if_map_must_be_handled()` to classify nodes. + + If the function returns `True` it means that the input/output, does not + violates an internal constraint, i.e. can be handled by + `_ensure_that_map_is_pointwise()`. If appropriate the function will add the + node to `known_nodes`. I.e. in case of a transient the function will return + `True` but will not add it to `known_nodes`. + """ + + # This case is indicating that the `ConsolidateEdges` has not fully worked. + # Currently the transformation implementation assumes that this is the + # case, so we can not handle this case. + # TODO(phimuell): Implement this case. + if data_node.data in known_nodes: + return False + data_desc: dace_data.Data = data_node.desc(sdfg) + + # The conflict can only occur for global data, because transients + # are only written once. + if data_desc.transient: + return False + + # Currently we do not handle view, as they need to be traced. + # TODO(phimuell): Implement + if gtx_transformations.util.is_view(data_desc, sdfg): + return False + + # TODO(phimuell): Check if there is a access node on the inner side, then we do not have to do it. + + # Now add the node to the list. + assert all(data_node is not known_node for known_node in known_nodes.values()) + known_nodes[data_node.data] = data_node + return True + + +def _get_inner_edges( + outer_edge: dace.sdfg.graph.MultiConnectorEdge, + scope_node: dace_nodes.MapExit | dace_nodes.MapEntry, + state: dace.SDFG, + outgoing_edge: bool, +) -> list[dace.sdfg.graph.MultiConnectorEdge]: + """Gets the edges on the inside of a map.""" + if outgoing_edge: + assert isinstance(scope_node, dace_nodes.MapExit) + conn_name = outer_edge.src_conn[4:] + return list(state.in_edges_by_connector(scope_node, connector="IN_" + conn_name)) + else: + assert isinstance(scope_node, dace_nodes.MapEntry) + conn_name = outer_edge.dst_conn[3:] + return list(state.out_edges_by_connector(scope_node, connector="OUT_" + conn_name)) + + +def _check_if_map_must_be_handled( + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None | dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]]: + """Check if the map should be processed to uphold rule 3. + + Essentially the function will check if there is a potential read-write + conflict. The function assumes that `ConsolidateEdges` has already run. + + If there is a possible data race the function will return a `dict`, that + maps the name of the data to the access nodes that are used as input and + output to the Map. + + Otherwise, the function returns `None`. It is, however, important that + `None` does not means that there is no possible race condition. It could + also means that the function that implements the buffering, i.e. + `_ensure_that_map_is_pointwise()`, is unable to handle this case. + + Todo: + Improve the function + """ + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + + # Find all the data that is accessed. Views are resolved. + input_datas: dict[str, dace_nodes.AccessNode] = {} + output_datas: dict[str, dace_nodes.AccessNode] = {} + + # Determine which nodes are possible conflicting. + for in_edge in state.in_edges(map_entry): + if in_edge.data.is_empty(): + continue + if not isinstance(in_edge.src, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if in_edge.dst_conn and not in_edge.dst_conn.startswith("IN_"): + # TODO(phimuell): It is very unlikely that a Dynamic Map Range causes + # this particular data race, so we ignore it for the time being. + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=in_edge.src, + sdfg=sdfg, + known_nodes=input_datas, + ): + continue + for out_edge in state.out_edges(map_exit): + if out_edge.data.is_empty(): + continue + if not isinstance(out_edge.dst, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=out_edge.dst, + sdfg=sdfg, + known_nodes=output_datas, + ): + continue + + # Double buffering is only needed if there inout arguments. + inout_datas: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]] = { + dname: (input_datas[dname], output_datas[dname]) + for dname in input_datas + if dname in output_datas + } + if len(inout_datas) == 0: + return None + + # TODO(phimuell): What about the case that some data descriptor needs double + # buffering, but some do not? + for inout_data_name in list(inout_datas.keys()): + input_node, output_node = inout_datas[inout_data_name] + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert ( + len(input_edges) == 1 + ), f"Expected a single connection between input node and map entry, but found {len(input_edges)}." + assert ( + len(output_edges) == 1 + ), f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + + # If there is only one edge on the inside of the map, that goes into an + # AccessNode, then we assume it is double buffered. + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + if ( + len(inner_read_edges) == 1 + and isinstance(inner_read_edges[0].dst, dace_nodes.AccessNode) + and not gtx_transformations.util.is_view(inner_read_edges[0].dst, sdfg) + ): + inout_datas.pop(inout_data_name) + continue + + inner_read_subsets = [ + inner_read_edge.data.get_src_subset(inner_read_edge, state) + for inner_read_edge in inner_read_edges + ] + assert all(inner_read_subset is not None for inner_read_subset in inner_read_subsets) + inner_write_subsets = [ + inner_write_edge.data.get_dst_subset(inner_write_edge, state) + for inner_write_edge in _get_inner_edges(output_edges[0], map_exit, state, True) + ] + # TODO(phimuell): Also implement a check that the volume equals the size of the subset. + assert all(inner_write_subset is not None for inner_write_subset in inner_write_subsets) + + # For being point wise the subsets must be compatible. The correct check would be: + # - The write sets are unique. + # - For every read subset there exists one matching write subset. It could + # be that there are many equivalent read subsets. + # - For every write subset there exists at least one matching read subset. + # The current implementation only checks if all are the same. + # TODO(phimuell): Implement the real check. + all_inner_subsets = inner_read_subsets + inner_write_subsets + if not all( + all_inner_subsets[0] == all_inner_subsets[i] for i in range(1, len(all_inner_subsets)) + ): + return None + + if len(inout_datas) == 0: + return None + + return inout_datas diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index d7326e1131..d401c06f15 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -59,20 +59,11 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) - independent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are independent of the blocking parameter.", - ) - dependent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are dependent on the blocking parameter.", - ) + # Set of nodes that are independent of the blocking parameter. + _independent_nodes: Optional[set[dace_nodes.AccessNode]] + _dependent_nodes: Optional[set[dace_nodes.AccessNode]] - outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + outer_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, @@ -86,6 +77,8 @@ def __init__( self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + self._independent_nodes = None + self._dependent_nodes = None @classmethod def expressions(cls) -> Any: @@ -125,6 +118,8 @@ def can_be_applied( return False if not self.partition_map_output(graph, sdfg): return False + self._independent_nodes = None + self._dependent_nodes = None return True @@ -137,7 +132,6 @@ def apply( Performs the operation described in the doc string. """ - # Now compute the partitions of the nodes. self.partition_map_output(graph, sdfg) @@ -153,10 +147,8 @@ def apply( state=graph, sdfg=sdfg, ) - - # Clear the old partitions - self.independent_nodes = None - self.dependent_nodes = None + self._independent_nodes = None + self._dependent_nodes = None def _prepare_inner_outer_maps( self, @@ -269,8 +261,8 @@ def partition_map_output( """ # Clear the previous partition. - self.independent_nodes = set() - self.dependent_nodes = None + self._independent_nodes = set() + self._dependent_nodes = None while True: # Find all the nodes that we have to classify in this iteration. @@ -279,9 +271,9 @@ def partition_map_output( nodes_to_classify: set[dace_nodes.Node] = { edge.dst for edge in state.out_edges(self.outer_entry) } - for independent_node in self.independent_nodes: + for independent_node in self._independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) - nodes_to_classify.difference_update(self.independent_nodes) + nodes_to_classify.difference_update(self._independent_nodes) # Now classify each node found_new_independent_node = False @@ -294,7 +286,7 @@ def partition_map_output( # Check if the partition exists. if class_res is None: - self.independent_nodes = None + self._independent_nodes = None return False if class_res is True: found_new_independent_node = True @@ -305,10 +297,10 @@ def partition_map_output( # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. - self.dependent_nodes = { + self._dependent_nodes = { edge.dst for edge in state.out_edges(self.outer_entry) - if edge.dst not in self.independent_nodes + if edge.dst not in self._independent_nodes } return True @@ -333,7 +325,7 @@ def _classify_node( Returns: The function returns `True` if `node_to_classify` is considered independent. - In this case the function will add the node to `self.independent_nodes`. + In this case the function will add the node to `self._independent_nodes`. If the function returns `False` the node was classified as a dependent node. The function will return `None` if the node can not be classified, in this case the partition does not exist. @@ -343,23 +335,50 @@ def _classify_node( state: The state containing the map. sdfg: The SDFG that is processed. """ + assert self._independent_nodes is not None # silence MyPy outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. + outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) + + # The node needs to have an input and output. + if state.in_degree(node_to_classify) == 0 or state.out_degree(node_to_classify) == 0: + return None # We are only able to handle certain kind of nodes, so screening them. if isinstance(node_to_classify, dace_nodes.Tasklet): if node_to_classify.side_effects: - # TODO(phimuell): Think of handling it. return None + + # A Tasklet must write to an AccessNode, because otherwise there would + # be nothing that could be used to cache anything. Furthermore, this + # AccessNode must be outside of the inner loop, i.e. be independent. + # TODO: Make this check stronger to ensure that there is always an + # AccessNode that is independent. + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + elif isinstance(node_to_classify, dace_nodes.AccessNode): # AccessNodes need to have some special properties. node_desc: dace.data.Data = node_to_classify.desc(sdfg) - if isinstance(node_desc, dace.data.View): # Views are forbidden. return None - if node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: - # The access node has to life fully within the scope. + + # The access node inside either has scope lifetime or is a scalar. + if isinstance(node_desc, dace.data.Scalar): + pass + elif node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: return None + + elif isinstance(node_to_classify, dace_nodes.MapEntry): + # We classify `MapEntries` as dependent nodes, we could now start + # looking if the whole map is independent, but it is currently an + # overkill. + return False + else: # Any other node type we can not handle, so the partition can not exist. # TODO(phimuell): Try to handle certain kind of library nodes. @@ -376,29 +395,12 @@ def _classify_node( # for these classification to make sense the partition has to exist in the # first place. - # Either all incoming edges of a node are empty or none of them. If it has - # empty edges, they are only allowed to come from the map entry. - found_empty_edges, found_nonempty_edges = False, False - for in_edge in in_edges: - if in_edge.data.is_empty(): - found_empty_edges = True - if in_edge.src is not outer_entry: - # TODO(phimuell): Lift this restriction. - return None - else: - found_nonempty_edges = True - - # Test if we found a mixture of empty and nonempty edges. - if found_empty_edges and found_nonempty_edges: - return None - assert ( - found_empty_edges or found_nonempty_edges - ), f"Node '{node_to_classify}' inside '{outer_entry}' without an input connection." - - # Requiring that all output Memlets are non empty implies, because we are - # inside a scope, that there exists an output. - if any(out_edge.data.is_empty() for out_edge in state.out_edges(node_to_classify)): - return None + # There are some very small requirements that we impose on the output edges. + for out_edge in state.out_edges(node_to_classify): + # We consider nodes that are directly connected to the outer map exit as + # dependent. This is an implementation detail to avoid some hard cases. + if out_edge.dst is outer_exit: + return False # Now we have ensured that the partition exists, thus we will now evaluate # if the node is independent or dependent. @@ -413,7 +415,7 @@ def _classify_node( # Now we have to look at incoming edges individually. # We will inspect the subset of the Memlet to see if they depend on the # block variable. If this loop ends normally, then we classify the node - # as independent and the node is added to `independent_nodes`. + # as independent and the node is added to `_independent_nodes`. for in_edge in in_edges: memlet: dace.Memlet = in_edge.data src_subset: dace_subsets.Subset | None = memlet.src_subset @@ -436,11 +438,11 @@ def _classify_node( # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. - if not (in_edge.src is outer_entry or in_edge.src in self.independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in self._independent_nodes): return False # Loop ended normally, thus we classify the node as independent. - self.independent_nodes.add(node_to_classify) + self._independent_nodes.add(node_to_classify) return True def _rewire_map_scope( @@ -467,116 +469,138 @@ def _rewire_map_scope( state: The state of the map. sdfg: The SDFG we operate on. """ + assert self._independent_nodes is not None and self._dependent_nodes is not None # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. - for independent_node in self.independent_nodes: - for out_edge in state.out_edges(independent_node): + # _output_ edges have to go through the new inner map and the Memlets + # need modifications, because of the block parameter. + for independent_node in self._independent_nodes: + for out_edge in list(state.out_edges(independent_node)): edge_dst: dace_nodes.Node = out_edge.dst relocated_nodes.add(edge_dst) # If destination of this edge is also independent we do not need # to handle it, because that node will also be before the new # inner serial map. - if edge_dst in self.independent_nodes: + if edge_dst in self._independent_nodes: continue # Now split `out_edge` such that it passes through the new inner entry. # We do not need to modify the subsets, i.e. replacing the variable # on which we block, because the node is independent and the outgoing # new inner map entry iterate over the blocked variable. - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, + if out_edge.data.is_empty(): + # `out_edge` is an empty Memlet that ensures its source, which is + # independent, is sequenced before its destination, which is + # dependent. We now have to split it into two. + # TODO(phimuell): Can we remove this edge? Is the map enough to + # ensure proper sequencing? + new_in_conn = None + new_out_conn = None + new_memlet_outside = dace.Memlet() + + elif not isinstance(independent_node, dace_nodes.AccessNode): + # For syntactical reasons there must be an access node on the + # outside of the (inner) scope, that acts as cache. The + # classification and this preconditions on SDFG should ensure + # that, but there are a few super hard edge cases. + # TODO(phimuell): Add an intermediate here in this case + raise NotImplementedError() + + else: + # NOTE: This creates more connections that are ultimately + # necessary. However, figuring out which one to use and if + # it would be valid, is very complicated, so we don't do it. + new_map_conn = inner_entry.next_connector(try_name=out_edge.data.data) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_outside = dace.Memlet.from_array( + out_edge.data.data, sdfg.arrays[out_edge.data.data] + ) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + + state.add_edge( + out_edge.src, + out_edge.src_conn, + inner_entry, + new_in_conn, + new_memlet_outside, ) - # TODO(phimuell): Check if there might be a subset error. state.add_edge( inner_entry, - "OUT_" + new_map_conn, + new_out_conn, out_edge.dst, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) + state.remove_edge(out_edge) # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in self.dependent_nodes: + # in that they _after_ the new inner map entry. Thus, we have to modify + # their incoming edges. + for dependent_node in self._dependent_nodes: for in_edge in state.in_edges(dependent_node): edge_src: dace_nodes.Node = in_edge.src - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. + # The incoming edge of a dependent node (before any processing) either + # starts at: + # - The outer map. + # - An other dependent node. + # - An independent node. + # The last case was already handled by the loop above. if edge_src is inner_entry: + # Edge originated originally at an independent node, but was + # already handled by the loop above. assert dependent_node in relocated_nodes - continue - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + elif edge_src is not outer_entry: + # Edge originated at an other dependent node. There is nothing + # that we have to do. + # NOTE: We can not test if `edge_src` is in `self._dependent_nodes` + # because it only contains the dependent nodes that are directly + # connected to the map entry. + assert edge_src not in self._independent_nodes + + elif in_edge.data.is_empty(): + # The dependent node has an empty Memlet to the other map. + # Since the inner map is sequenced after the outer map, + # we will simply reconnect the edge to the inner map. + # TODO(phimuell): Are there situations where this makes problems. dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, - ) - - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr + elif edge_src is outer_entry: + # This dependent node originated at the outer map. Thus we have to + # split the edge, such that it now passes through the inner map. + new_map_conn = inner_entry.next_connector(try_name=in_edge.src_conn[4:]) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_inner = dace.Memlet.from_array( + in_edge.data.data, sdfg.arrays[in_edge.data.data] ) - - else: - # This is the first time we found this connection. - # so we just create the edge. state.add_edge( - outer_entry, - "OUT_" + edge_conn, + in_edge.src, + in_edge.src_conn, inner_entry, - "IN_" + edge_conn, + new_in_conn, + new_memlet_inner, + ) + state.add_edge( + inner_entry, + new_out_conn, + in_edge.dst, + in_edge.dst_conn, copy.deepcopy(in_edge.data), ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + state.remove_edge(in_edge) + + else: + raise NotImplementedError("Unknown node configuration.") # In certain cases it might happen that we need to create an empty # Memlet between the outer map entry and the inner one. @@ -593,7 +617,7 @@ def _rewire_map_scope( # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] + edge_conn = inner_exit.next_connector(in_edge.dst_conn[3:]) dace_helpers.redirect_edge( state=state, edge=in_edge, @@ -610,5 +634,9 @@ def _rewire_map_scope( inner_exit.add_in_connector("IN_" + edge_conn) inner_exit.add_out_connector("OUT_" + edge_conn) + # There is an invalid cache state in the SDFG, that makes the memlet + # propagation fail, to clear the cache we call the hash function. + # See: https://github.com/spcl/dace/issues/1703 + _ = sdfg.hash_sdfg() # TODO(phimuell): Use a less expensive method. dace.sdfg.propagation.propagate_memlets_state(sdfg, state) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index ec33e7ea63..eceb07ed82 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -6,89 +6,106 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements helper functions for the map fusion transformations. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements Helper functionaliyies for map fusion -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. """ -import functools -import itertools -from typing import Any, Optional, Sequence, Union + +# ruff: noqa + +import copy +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, Callable, TypeAlias import dace -from dace import ( - data as dace_data, - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes, validation as dace_validation -from dace.transformation import helpers as dace_helpers - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import util - - -@dace_properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): - """Contains common part of the fusion for parallel and serial Map fusion. - - The transformation assumes that the SDFG obeys the principals outlined in - [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). - The main advantage of this structure is, that it is rather easy to determine - if a transient is used anywhere else. This check, performed by - `is_interstate_transient()`. It is further speeded up by cashing some computation, - thus such an object should not be used after interstate optimizations were applied - to the SDFG. +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, nodes, validation +from dace.transformation import helpers + +FusionCallback: TypeAlias = Callable[ + ["MapFusionHelper", nodes.MapEntry, nodes.MapEntry, dace.SDFGState, dace.SDFG, bool], bool +] +"""Callback for the map fusion transformation to check if a fusion should be performed. +""" + + +@properties.make_properties +class MapFusionHelper(transformation.SingleStateTransformation): + """Common parts of the parallel and serial map fusion transformation. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + If `strict_dataflow` mode is enabled then the transformation will not remove + _direct_ data flow dependency from the graph. Furthermore, the transformation + will not remove size 1 dimensions of intermediate it creates. + This is a compatibility mode, that will limit the applicability of the + transformation, but might help transformations that do not fully analyse + the graph. """ - only_toplevel_maps = dace_properties.Property( + only_toplevel_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) - only_inner_maps = dace_properties.Property( + only_inner_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = dace_properties.DictProperty( - key_type=dace.SDFG, - value_type=set[str], - default=None, - allow_none=True, - desc="Maps SDFGs to the set of array transients that can not be removed. " - "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + strict_dataflow = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation will ensure a more stricter data flow.", ) + # Callable that can be specified by the user, if it is specified, it should be + # a callable with the same signature as `can_be_fused()`. If the function returns + # `False` then the fusion will be rejected. + _apply_fusion_callback: Optional[FusionCallback] + + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. + _shared_data: Dict[SDFG, Set[str]] + def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + apply_fusion_callback: Optional[FusionCallback] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) + self._shared_data = {} + self._apply_fusion_callback = None if only_toplevel_maps is not None: self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) - self.shared_transients = {} + if strict_dataflow is not None: + self.strict_dataflow = bool(strict_dataflow) + if apply_fusion_callback is not None: + self._apply_fusion_callback = apply_fusion_callback @classmethod def expressions(cls) -> bool: - raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") def can_be_fused( self, - map_entry_1: dace_nodes.MapEntry, - map_entry_2: dace_nodes.MapEntry, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, @@ -97,13 +114,11 @@ def can_be_fused( This function only checks constrains that are common between serial and parallel map fusion process, which includes: + - The registered callback, if specified. - The scope of the maps. - The scheduling of the maps. - The map parameters. - However, for performance reasons, the function does not check if the node - decomposition exists. - Args: map_entry_1: The entry of the first (in serial case the top) map. map_exit_2: The entry of the second (in serial case the bottom) map. @@ -111,6 +126,13 @@ def can_be_fused( sdfg: The SDFG itself. permissive: Currently unused. """ + # Consult the callback if defined. + if self._apply_fusion_callback is not None: + if not self._apply_fusion_callback( + self, map_entry_1, map_entry_2, graph, sdfg, permissive + ): + return False + if self.only_inner_maps and self.only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") @@ -128,26 +150,22 @@ def can_be_fused( elif self.only_toplevel_maps: if scope[map_entry_1] is not None: return False - # TODO(phimuell): Figuring out why this is here. - elif util.is_nested_sdfg(sdfg): - return False - # We will now check if there exists a "remapping" that we can use. - # NOTE: The serial map promoter depends on the fact that this is the - # last check. - if not self.map_parameter_compatible( - map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + # We will now check if there exists a remapping of the map parameter + if ( + self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) + is None ): return False return True - @staticmethod def relocate_nodes( - from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - state: dace.SDFGState, - sdfg: dace.SDFG, + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, ) -> None: """Move the connectors and edges from `from_node` to `to_nodes` node. @@ -156,6 +174,7 @@ def relocate_nodes( once for the entry and then for the exit. While it does not remove the node themselves if guarantees that the `from_node` has degree zero. + The function assumes that the parameter renaming was already done. Args: from_node: Node from which the edges should be removed. @@ -165,22 +184,23 @@ def relocate_nodes( """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in filter(lambda e: e.data.is_empty(), state.out_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in filter(lambda e: e.data.is_empty(), state.in_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. - empty_targets: set[dace_nodes.Node] = set() - for empty_edge in filter(lambda e: e.data.is_empty(), state.all_edges(to_node)): + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) # We now determine which edges we have to migrate, for this we are looking at # the incoming edges, because this allows us also to detect dynamic map ranges. - for edge_to_move in state.in_edges(from_node): + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) if not edge_to_move.dst_conn.startswith("IN_"): @@ -200,36 +220,32 @@ def relocate_nodes( raise RuntimeError( # Might fail because of out connectors. f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." ) - dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) - # There is no other edge that we have to consider, so we just end here - continue - - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in state.in_edges_by_connector(from_node, "IN_" + old_conn): - dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in state.out_edges_by_connector(from_node, "OUT_" + old_conn): - dace_helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) # Check if we succeeded. if state.out_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", sdfg, sdfg.node_id(state), ) if state.in_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", sdfg, sdfg.node_id(state), @@ -237,330 +253,442 @@ def relocate_nodes( assert len(from_node.in_connectors) == 0 assert len(from_node.out_connectors) == 0 - @staticmethod - def map_parameter_compatible( - map_1: dace_nodes.Map, - map_2: dace_nodes.Map, - state: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_2`. + def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map + ) -> Union[Dict[str, str], None]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at all is _needed_, i.e. all parameter have the same name and + range, then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. - The check follows the following rules: - - The names of the map variables must be the same, i.e. no renaming - is performed. - - The ranges must be the same. + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. """ - range_1: dace_subsets.Range = map_1.range - params_1: Sequence[str] = map_1.params - range_2: dace_subsets.Range = map_2.range - params_2: Sequence[str] = map_2.params - - # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. This is in accordance with the - # rules. - if set(params_1) != set(params_2): - return False - # Maps the name of a parameter to the dimension index - param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.discard(param) + unused_first_params.discard(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None - # To fuse the two maps the ranges must have the same ranges - for pname in params_1: - idx_1 = param_dim_map_1[pname] - idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify? - if range_1[idx_1] != range_2[idx_2]: - return False + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping - return True + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + Args: + first_map: The first map (these are the final parameter). + second_map: The second map, this map will be replaced. + second_map_entry: The entry node of the second map. + state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=first_map, second_map=second_map + ) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) - def is_interstate_transient( + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + def is_shared_data( self, - transient: Union[str, dace_nodes.AccessNode], + data: nodes.AccessNode, sdfg: dace.SDFG, - state: dace.SDFGState, ) -> bool: - """Tests if `transient` is an interstate transient, an can not be removed. - - Essentially this function checks if a transient might be needed in a - different state in the SDFG, because it transmit information from - one state to the other. - If only the name of the data container is passed the function will - first look for an corresponding access node. + """Tests if `data` is interstate data, an can not be removed. - The set of these "interstate transients" is computed once per SDFG. - The result is then cached internally for later reuse. + Interstate data is used to transmit data between multiple state or by + extension within the state. Thus it must be classified as a shared output. + This function will go through the SDFG to and collect the names of all data + container that should be classified as shared. Note that this is an over + approximation as it does not take the location into account, i.e. "is no longer + used". Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. - state: If given the state the node is located in. + + Note: + The function computes the this set once for every SDFG and then caches it. + There is no mechanism to detect if the cache must be evicted. However, + as long as no additional data is added, there is no problem. """ + if sdfg not in self._shared_data: + self._compute_shared_data(sdfg) + return data.data in self._shared_data[sdfg] - # The following builds upon the HACK MD document and not on ADR0018. - # Therefore the numbers are slightly different, but both documents - # essentially describes the same SDFG. - # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access dace_nodes. - # Because of rule 3 we also include all scalars in this set, as an over - # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink dace_nodes. - - # See if we have already computed the set - if sdfg in self.shared_transients: - shared_sdfg_transients: set[str] = self.shared_transients[sdfg] - else: - # SDFG is not known so we have to compute the set. - shared_sdfg_transients = set() - for state_to_scan in sdfg.all_states(): - # TODO(phimuell): Use `all_nodes_recursive()` once it is available. - shared_sdfg_transients.update( - [ - node.data - for node in itertools.chain( - state_to_scan.source_nodes(), state_to_scan.sink_nodes() - ) - if isinstance(node, dace_nodes.AccessNode) - and sdfg.arrays[node.data].transient - ] + def _compute_shared_data( + self, + sdfg: dace.SDFG, + ) -> None: + """Updates the internal set of shared data/interstate data of `self` for `sdfg`. + + See the documentation for `self.is_shared_data()` for a description. + + Args: + sdfg: The SDFG for which the set of shared data should be computed. + """ + # Shared data of this SDFG. + shared_data: Set[str] = set() + + # All global data can not be removed, so it must always be shared. + for data_name, data_desc in sdfg.arrays.items(): + if not data_desc.transient: + shared_data.add(data_name) + elif isinstance(data_desc, dace.data.Scalar): + shared_data.add(data_name) + + # We go through all states and classify the nodes/data: + # - Data is referred to in different states. + # - The access node is a view (both have to survive). + # - Transient sink or source node. + # - The access node has output degree larger than 1 (input degrees larger + # than one, will always be partitioned as shared anyway). + prevously_seen_data: Set[str] = set() + interstate_read_symbols: Set[str] = set() + for state in sdfg.nodes(): + for access_node in state.data_nodes(): + if access_node.data in shared_data: + # The data was already classified to be shared data + pass + + elif access_node.data in prevously_seen_data: + # We have seen this data before, either in this state or in + # a previous one, but we did not classifies it as shared back then + shared_data.add(access_node.data) + + if state.in_degree(access_node) == 0: + # (Transient) sink nodes are used in other states, or simplify + # will get rid of them. + shared_data.add(access_node.data) + + elif ( + state.out_degree(access_node) != 1 + ): # state.out_degree() == 0 or state.out_degree() > 1 + # The access node is either a source node (it is shared in another + # state) or the node has a degree larger than one, so it is used + # in this state somewhere else. + shared_data.add(access_node.data) + + elif self.is_view(node=access_node, sdfg=sdfg): + # To ensure that the write to the view happens, both have to be shared. + viewed_data: str = self.track_view( + view=access_node, state=state, sdfg=sdfg + ).data + shared_data.update([access_node.data, viewed_data]) + prevously_seen_data.update([access_node.data, viewed_data]) + + else: + # The node was not classified as shared data, so we record that + # we saw it. Note that a node that was immediately classified + # as shared node will never be added to this set, but a data + # that was found twice will be inside this list. + prevously_seen_data.add(access_node.data) + + # Now we are collecting all symbols that interstate edges read from. + for edge in sdfg.edges(): + interstate_read_symbols.update(edge.data.read_symbols()) + + # We also have to keep everything the edges referrers to and is an array. + shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + + # Update the internal cache + self._shared_data[sdfg] = shared_data + + def _compute_multi_write_data( + self, + state: SDFGState, + sdfg: SDFG, + ) -> Set[str]: + """Computes data inside a _single_ state, that is written multiple times. + + Essentially this function computes the set of data that does not follow + the single static assignment idiom. The function also resolves views. + If an access node, refers to a view, not only the view itself, but also + the data it refers to is added to the set. + + Args: + state: The state that should be examined. + sdfg: The SDFG object. + + Note: + This information is used by the partition function (in case strict data + flow mode is enabled), in strict data flow mode only. The current + implementation is rather simple as it only checks if a data is written + to multiple times in the same state. + """ + data_written_to: Set[str] = set() + multi_write_data: Set[str] = set() + + for access_node in state.data_nodes(): + if state.in_degree(access_node) == 0: + continue + if access_node.data in data_written_to: + multi_write_data.add(access_node.data) + elif self.is_view(access_node, sdfg): + # This is an over approximation. + multi_write_data.update( + [access_node.data, self.track_view(access_node, state, sdfg).data] ) - self.shared_transients[sdfg] = shared_sdfg_transients - - if isinstance(transient, str): - name = transient - matching_access_nodes = [node for node in state.data_nodes() if node.data == name] - # Rule 8: There is only one access node per state for data. - assert len(matching_access_nodes) == 1 - transient = matching_access_nodes[0] - else: - assert isinstance(transient, dace_nodes.AccessNode) - name = transient.data + data_written_to.add(access_node.data) + return multi_write_data - desc: dace_data.Data = sdfg.arrays[name] - if not desc.transient: - return True - if isinstance(desc, dace_data.Scalar): - return True # Scalars can not be removed by fusion anyway. + def is_node_reachable_from( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The node that should be located. + """ - # Rule 8: If degree larger than one then it is used within the state. - if state.out_degree(transient) > 1: - return True + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) - # Now we check if it is used in a different state. - return name in shared_sdfg_transients + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() - def partition_first_outputs( + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + def get_access_set( self, - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - ) -> Union[ - tuple[ - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. + scope_node: The scope node that should be evaluated. + state: The state in which we operate. """ - # The three outputs set. - pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) # noqa: E731 + other_node = lambda e: e.src # noqa: E731 + else: + get_edges = lambda node: state.out_edges(node) # noqa: E731 + other_node = lambda e: e.dst # noqa: E731 + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[dace_nodes.Node] = set() + return access_set - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: dace_nodes.Node = out_edge.dst + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = util.all_nodes_between( - graph=state, - begin=intermediate_node, - end=map_entry_2, + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset # noqa: E731 + get_inner_edges = lambda e: state.out_edges_by_connector( + scope_node, "OUT_" + e.dst_conn[3:] + ) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset # noqa: E731 + get_inner_edges = lambda e: state.in_edges_by_connector( + scope_node, "IN_" + e.src_conn[4:] ) - # If `downstream_nodes` is `None` this means that `map_entry_2` was never - # reached, thus `intermediate_node` does not enter the second map and - # the node is a pure output node. - if downstream_nodes is None: - pure_outputs.add(out_edge) - continue + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None + return found_subsets - # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. Currently we restrict this - # even further by saying that it has only one incoming Memlet. - if state.in_degree(intermediate_node) != 1: - return None + def is_view( + self, + node: nodes.AccessNode, + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) - ) - if len(inner_collector_edges) > 1: - return None + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, dace_nodes.AccessNode): - return None - intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, dace_data.View): - return None + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. - # There are some restrictions we have on intermediate dace_nodes. The first one - # is that we do not allow WCR, this is because they need special handling - # which is currently not implement (the DaCe transformation has this - # restriction as well). The second one is that we can reduce the - # intermediate node and only feed a part into the second map, consider - # the case `b = a + 1; return b + 2`, where we have arrays. In this - # example only a single element must be available to the second map. - # However, this is hard to check so we will make a simplification. - # First, we will not check it at the producer, but at the consumer point. - # There we assume if the consumer does _not consume the whole_ - # intermediate array, then we can decompose the intermediate, by setting - # the map iteration index to zero and recover the shape, see - # implementation in the actual fusion routine. - # This is an assumption that is in most cases correct, but not always. - # However, doing it correctly is extremely complex. - for _, produce_edge in util.find_upstream_producers(state, out_edge): - if produce_edge.data.wcr is not None: - return None - - if len(downstream_nodes) == 0: - # There is nothing between intermediate node and the entry of the - # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`. - - # This is a very special situation, i.e. the access node has many - # different connections to the second map entry, this is a special - # case that we do not handle. - # TODO(phimuell): Handle this case. - if state.out_degree(intermediate_node) != 1: - return None - - # Certain nodes need more than one element as input. As explained - # above, in this situation we assume that we can naturally decompose - # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range or a library node. - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) - for consumer_node, feed_edge in consumers: - # TODO(phimuell): Improve this approximation. - if ( - intermediate_size != 1 - ) and feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - - # Note that "remove" has a special meaning here, regardless of the - # output of the check function, from within the second map we remove - # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?" - if self.is_interstate_transient(intermediate_node, sdfg, state): - shared_outputs.add(out_edge) - else: - exclusive_outputs.add(out_edge) - continue + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ - else: - # There is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connections, thus - # the node might belong to the shared output. Of the many different - # possibilities, we only consider a single case: - # - The intermediate has a single connection to the second map, that - # fulfills the restriction outlined above. - # - All other connections have no connection to the second map. - found_second_entry = False - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - for edge in state.out_edges(intermediate_node): - if edge.dst is map_entry_2: - if found_second_entry: # The second map was found again. - return None - found_second_entry = True - consumers = util.find_downstream_consumers(state=state, begin=edge) - for consumer_node, feed_edge in consumers: - if feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - else: - # Ensure that there is no path that leads to the second map. - after_intermdiate_node = util.all_nodes_between( - graph=state, begin=edge.dst, end=map_entry_2 - ) - if after_intermdiate_node is not None: - return None - # If we are here, then we know that the node is a shared output - shared_outputs.add(out_edge) - continue + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError( + f"Failed to determine the direction of the view '{view}' | {curr_edge}." + ) - assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) - return (pure_outputs, exclusive_outputs, shared_outputs) + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while self.is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py new file mode 100644 index 0000000000..19412b9dfa --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py @@ -0,0 +1,170 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the parallel map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +from typing import Any, Optional, Set, Union + +import dace +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionParallel(mfh.MapFusionHelper): + """The `MapFusionParallel` transformation allows to merge two parallel maps. + + While the `MapFusionSerial` transformation fuses maps that are sequentially + connected through an intermediate node, this transformation is able to fuse any + two maps that are not sequential and in the same scope. + + Args: + only_if_common_ancestor: Only perform fusion if both Maps share at least one + node as direct ancestor. This will increase the locality of the merge. + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + This transformation only matches the entry nodes of the Map, but will also + modify the exit nodes of the Maps. + """ + + map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + only_if_common_ancestor = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps share a node as parent.", + ) + + def __init__( + self, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + # This just matches _any_ two Maps inside a state. + state = graph.OrderedMultiDiConnectorGraph() + state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) + return [state] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Checks if the fusion can be done. + + The function checks the general fusing conditions and if the maps are parallel. + """ + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # Check the structural properties of the maps, this will also ensure that + # the two maps are in the same scope and the parameters can be renamed + if not self.can_be_fused( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, + ): + return False + + # Since the match expression matches any two Maps, we have to ensure that + # the maps are parallel. The `can_be_fused()` function already verified + # if they are in the same scope. + if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # This assumes that there is only one access node per data container in the state. + ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): + return False + + return True + + def is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # In order to be parallel they must be in the same scope. + scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): + return False + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): + return False + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map fusing. + + Essentially, the function relocate all edges from the scope nodes (`MapEntry` + and `MapExit`) of the second map to the scope nodes of the first map. + """ + + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_entry_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py new file mode 100644 index 0000000000..2cdcc455d4 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py @@ -0,0 +1,1007 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the serial map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionSerial(mfh.MapFusionHelper): + """Fuse two serial maps together. + + The transformation combines two maps into one that are connected through some + access nodes. Conceptually this transformation removes the exit of the first + or upper map and the entry of the lower or second map and then rewrites the + connections appropriately. Depending on the situation the transformation will + either fully remove or make the intermediate a new output of the second map. + + By default, the transformation does not use the strict data flow mode, see + `MapFusionHelper` for more, however, it might be useful in come cases to enable + it, especially in the context of DaCe's auto optimizer. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Notes: + - This transformation modifies more nodes than it matches. + - After the transformation has been applied simplify should be run to remove + some dead data flow, that was introduced to ensure validity. + - A `MapFusionSerial` object can be initialized and be reused. However, + after new access nodes are added to any state, it is no longer valid + to use the object. + + Todo: + - Consider the case that only shared nodes are created (thus no inspection of + the graph is needed) and make all shared. Then use the dead dataflow + elimination transformation to get rid of the ones we no longer need. + - Increase the applicability. + """ + + map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) + intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [ + dace.sdfg.utils.node_path_graph( + cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2 + ) + ] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - Satisfies general requirements, see `MapFusionHelper.can_be_fused()`. + - Tests if the decomposition exists. + - Tests if there are read write dependencies. + """ + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # This essentially test the structural properties of the two Maps. + if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + state=graph, + sdfg=sdfg, + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks two different things. + - The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets. + - The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as inputs or outputs + at the same time. However, the function will not check for read write + conflicts in this set, this is done in the partition function. + + Returns: + `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` + is returned. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The entry node of the second map. + state: The state on which we operate. + sdfg: The SDFG on which we operate. + """ + map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append( + { + self.track_view(node, state, sdfg).data + if self.is_view(node, sdfg) + else node.data + for node in unresolved_set.values() + } + ) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # If there is no overlap in what is (totally) read and written, there will be no conflict. + # This must come before the check of disjoint write. + if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): + return False + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( + read_map_2.values() + ) + + # If the number are different then a data is accessed through multiple nodes. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can be used as input and output. But we do not allow that + # it is also used as intermediate or exchange data. This is an important check. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # Get the replacement dict for changing the map variables from the subsets of + # the second map. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets: List[subsets.Subset] = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=map_entry_1, + state=state, + sdfg=sdfg, + repl_dict=None, + ) + ) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=map_exit_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + ) + ) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # No read write dependency was found. + return False + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + Args: + subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Optional[subsets.Range] = None, + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + Args: + original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + intermediate_desc: The original intermediate data descriptor. + map_params: The parameter of the final map. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + final_offset = subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError( + f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." + ) + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + assert repl_dict is not None + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # These are the data that is written to multiple times in _this_ state. + # If a data is written to multiple time in a state, it could be + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. + # Thus we will never modify such intermediate nodes and fail instead. + if self.strict_dataflow: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) + else: + multi_write_data = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + if self.is_view(intermediate_node, sdfg): + return None + + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. + if intermediate_node.data in multi_write_data: + return None + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Memlets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view( + producer_edge.src, sdfg + ): + return None + if producer_edge.data.dynamic: + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not map_entry_2: + if self.is_node_reachable_from( + graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2 + ): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE: The subset still uses the old iteration variables. + for inner_consumer_edge in state.out_edges_by_connector( + map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:] + ): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert ( + found_second_map + ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace( + mapping=repl_dict, replace_callback=consumer_subset.replace + ) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum( + producer_subset.covers(consumer_subset) for producer_subset in producer_subsets + ) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is used somewhere else, either in this or another state. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit_1, nodes.MapExit) + assert isinstance(self.map_entry_2, nodes.MapEntry) + + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = map_exit_1.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + storage=dtypes.StorageType.Register, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + ) + + # Memlets have a lot of additional informations, such as dynamic. + # To ensure that we get all of them, we will now copy them and modify + # the one that was originally there. We also hope that propagate will + # set the rest for us correctly. + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # NOTE: We will delete the previous edge later. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( + include_self=False + ): + producer_edge = producer_tree.edge + + # Associate the (already existing) Memlet with the new data. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + producer_edge.data.data = new_inter_name + + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now we create a new connection that instead reads from the new + # intermediate, instead of the old one. For this we use the + # old Memlet as template. However it is not fully initialized. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( + include_self=False + ): + assert consumer_tree.edge.data.data == inter_name + + consumer_edge = consumer_tree.edge + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert pre_exit_edge.data.data == inter_name + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index 4b34dd6adc..8fb41c7d0a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -16,12 +16,42 @@ from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + return sdfg.apply_transformations_once_everywhere( + MapIterationOrder( + leading_dims=leading_dim, + ), + validate=validate, + validate_all=validate_all, + ) + + @dace_properties.make_properties class MapIterationOrder(dace_transformation.SingleStateTransformation): """Modify the order of the iteration variables. The iteration order, while irrelevant from an SDFG point of view, is highly - relevant in code, and the fastest varying index ("inner most loop" in CPU or + relevant in code and the fastest varying index ("inner most loop" in CPU or "x block dimension" in GPU) should be associated with the stride 1 dimension of the array. This transformation will reorder the map indexes such that this is the case. @@ -29,9 +59,18 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): While the place of the leading dimension is clearly defined, the order of the other loop indexes, after this transformation is unspecified. + The transformation accepts either a single dimension or a list of dimensions. + In case a list is passed this is interpreted as priorities. + Assuming we have the `leading_dim=[EdgeDim, VertexDim]`, then we have the + following: + - `Map[EdgeDim, KDim, VertexDim]` -> `Map[KDim, VertexDim, EdgeDim]`. + - `Map[VertexDim, KDim]` -> `Map[KDim, VertexDim]`. + - `Map[EdgeDim, KDim]` -> `Map[KDim, EdgeDim]`. + - `Map[CellDim, KDim]` -> `Map[CellDim, KDim]` (no modification). + Args: - leading_dim: A GT4Py dimension object that identifies the dimension that - is supposed to have stride 1. + leading_dim: GT4Py dimensions that are associated with the dimension that is + supposed to have stride 1. If it is a list it is used as a ranking. Note: The transformation does follow the rules outlines in @@ -44,25 +83,33 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): - Maybe also process the parameters to bring them in a canonical order. """ - leading_dim = dace_properties.Property( - dtype=str, + leading_dims = dace_properties.ListProperty( + element_type=str, allow_none=True, - desc="Dimension that should become the leading dimension.", + default=None, + desc="Dimensions that should become the leading dimension.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, + leading_dims: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - if isinstance(leading_dim, gtx_common.Dimension): - self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) - elif leading_dim is not None: - self.leading_dim = leading_dim + if isinstance(leading_dims, (gtx_common.Dimension, str)): + leading_dims = [leading_dims] + if isinstance(leading_dims, list): + self.leading_dims = [ + leading_dim + if isinstance(leading_dim, str) + else gtx_dace_fieldview_util.get_map_variable(leading_dim) + for leading_dim in leading_dims + ] @classmethod def expressions(cls) -> Any: @@ -80,16 +127,15 @@ def can_be_applied( Essentially the function checks if the selected dimension is inside the map, and if so, if it is on the right place. """ - - if self.leading_dim is None: + if self.leading_dims is None: return False map_entry: dace_nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params - map_var: str = self.leading_dim + processed_dims: set[str] = set(self.leading_dims) - if map_var not in map_params: + if not any(map_param in processed_dims for map_param in map_params): return False - if map_params[-1] == map_var: # Already at the correct location + if self.compute_map_param_order() is None: return False return True @@ -104,22 +150,52 @@ def apply( `self.leading_dim` the last map variable (this is given by the structure of DaCe's code generator). """ + map_object: dace_nodes.Map = self.map_entry.map + new_map_params_order: list[int] = self.compute_map_param_order() # type: ignore[assignment] # Guaranteed to be not `None`. + + def reorder(what: list[Any]) -> list[Any]: + assert isinstance(what, list) + return [what[new_pos] for new_pos in new_map_params_order] + + map_object.params = reorder(map_object.params) + map_object.range.ranges = reorder(map_object.range.ranges) + map_object.range.tile_sizes = reorder(map_object.range.tile_sizes) + + def compute_map_param_order(self) -> Optional[list[int]]: + """Computes the new iteration order of the matched map. + + The function returns a list, the value at index `i` indicates the old dimension + that should be put at the new location. If the order is already correct then + `None` is returned. + """ map_entry: dace_nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params - map_var: str = self.leading_dim - - # This implementation will just swap the variable that is currently the last - # with the one that should be the last. - dst_idx = -1 - src_idx = map_params.index(map_var) - - for to_process in [ - map_entry.map.params, - map_entry.map.range.ranges, - map_entry.map.range.tile_sizes, - ]: - assert isinstance(to_process, list) - src_val = to_process[src_idx] - dst_val = to_process[dst_idx] - to_process[dst_idx] = src_val - to_process[src_idx] = dst_val + org_mapping: dict[str, int] = {map_param: i for i, map_param in enumerate(map_params)} + leading_dims: list[str] = self.leading_dims + + # We divide the map parameters into two groups, the one we care and the others. + map_params_to_order: set[str] = { + map_param for map_param in map_params if map_param in leading_dims + } + + # If there is nothing to order, then we are done. + if not map_params_to_order: + return None + + # We start with all parameters that we ignore/do not care about. + new_map_params: list[str] = [ + map_param for map_param in map_params if map_param not in leading_dims + ] + + # Because how code generation works the leading dimension must be the most + # left one. Because this is also `self.leading_dims[0]` we have to process + # then in reverse order. + for map_param_to_check in reversed(leading_dims): + if map_param_to_check in map_params_to_order: + new_map_params.append(map_param_to_check) + assert len(map_params) == len(new_map_params) + + if map_params == new_map_params: + return None + + return [org_mapping[new_map_param] for new_map_param in new_map_params] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 19818fd3d1..46d46c4bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -299,9 +299,9 @@ class SerialMapPromoter(BaseMapPromoter): """ # Pattern Matching - exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + exit_first_map = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -346,17 +346,11 @@ def _test_if_promoted_maps_can_be_fused( ) -> bool: """This function checks if the promoted maps can be fused by map fusion. - This function assumes that `self.can_be_applied()` returned `True`. + This function assumes that `super().self.can_be_applied()` returned `True`. Args: state: The state in which we operate. sdfg: The SDFG we process. - - Note: - The current implementation uses a very hacky way to test this. - - Todo: - Find a better way of doing it. """ first_map_exit: dace_nodes.MapExit = self.exit_first_map access_node: dace_nodes.AccessNode = self.access_node @@ -373,23 +367,17 @@ def _test_if_promoted_maps_can_be_fused( # This will lead to a promotion of the map, this is needed that # Map fusion can actually inspect them. self.apply(graph=state, sdfg=sdfg) - - # Now create the map fusion object that we can then use to check if - # the fusion is possible or not. - serial_fuser = gtx_transformations.SerialMapFusion( - only_inner_maps=self.only_inner_maps, - only_toplevel_maps=self.only_toplevel_maps, - ) - candidate = { - type(serial_fuser).map_exit1: first_map_exit, - type(serial_fuser).access_node: access_node, - type(serial_fuser).map_entry2: second_map_entry, - } - state_id = sdfg.node_id(state) - serial_fuser.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) - - # Now use the serial fuser to see if fusion would succeed - if not serial_fuser.can_be_applied(state, 0, sdfg): + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + expr_index=0, + options={ + "only_inner_maps": self.only_inner_maps, + "only_toplevel_maps": self.only_toplevel_maps, + }, + map_exit_1=first_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + ): return False finally: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py deleted file mode 100644 index bca5aa2268..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ /dev/null @@ -1,483 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements the serial map fusing transformation. - -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. -""" - -import copy -from typing import Any, Union - -import dace -from dace import ( - dtypes as dace_dtypes, - properties as dace_properties, - subsets as dace_subsets, - symbolic as dace_symbolic, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper - - -@dace_properties.make_properties -class SerialMapFusion(map_fusion_helper.MapFusionHelper): - """Specialized replacement for the map fusion transformation that is provided by DaCe. - - As its name is indicating this transformation is only able to handle Maps that - are in sequence. Compared to the native DaCe transformation, this one is able - to handle more complex cases of connection between the maps. In that sense, it - is much more similar to DaCe's `SubgraphFusion` transformation. - - Things that are improved, compared to the native DaCe implementation: - - Nested Maps. - - Temporary arrays and the correct propagation of their Memlets. - - Top Maps that have multiple outputs. - - Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewrites the connections - appropriately. - - This transformation assumes that an SDFG obeys the structure that is outlined - [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that - reason it is not true replacement of the native DaCe transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - - Notes: - - This transformation modifies more nodes than it matches! - """ - - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. - """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constraints. - - The decomposition exists and at least one of the intermediate sets - is not empty. - """ - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. - """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) - - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. - pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. - self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - @staticmethod - def handle_intermediate_set( - intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - map_exit_2: dace_nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. - - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - - Todo: - Rewrite using `MemletTree`. - """ - - # Essentially this function removes the AccessNode between the two maps. - # However, we still need some temporary memory that we can use, which is - # just much smaller, i.e. a scalar. But all Memlets inside the second map - # assumes that the intermediate memory has the bigger shape. - # To fix that we will create this replacement dict that will replace all - # occurrences of the iteration variables of the second map with zero. - # Note that this is still not enough as the dimensionality might be different. - memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: dace_nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] # These are the dimensions we removed. - new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - # Order of checks is important! - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - - # This is the name of the new "intermediate" node that we will create. - # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. - if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - storage=dace_dtypes.StorageType.Register, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) - - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # we will delete the previous edge later. - pre_exit_memlet: dace.Memlet = pre_exit_edge.data - new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) - - # We might operate on a different array, but the check below, ensures - # that we do not change the direction of the Memlet. - assert pre_exit_memlet.data == inter_name - new_pre_exit_memlet.data = new_inter_name - - # Now we have to modify the subset of the Memlet. - # Before the subset of the Memlet was dependent on the Map variables, - # however, this is no longer the case, as we removed them. This change - # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the below is correct. - new_pre_exit_memlet.replace(memlet_repl) - if is_scalar: - new_pre_exit_memlet.subset = "0" - new_pre_exit_memlet.other_subset = None - else: - new_pre_exit_memlet.subset.pop(squeezed_dims) - - # Now we create the new edge between the producer and the new output - # (the new intermediate node). We will remove the old edge further down. - new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We just have handled the last Memlet, but we must actually handle the - # whole producer side, i.e. the scope of the top Map. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): - producer_edge = producer_tree.edge - - # Ensure the correctness of the rerouting below. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - - # Will not change the direction, because of test above! - producer_edge.data.data = new_inter_name - producer_edge.data.replace(memlet_repl) - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. - assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. - for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - - # The create the first Memlet to transmit information, within - # the second map, we do this again by copying and modifying - # the original Memlet. - # NOTE: Test above is important to ensure the direction of the - # Memlet and the correctness of the code below. - new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. - - # Now remove the old edge, that started the second map entry. - # Also add the new edge that started at the new intermediate. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now we do subset modification to ensure that nothing failed. - if is_scalar: - new_inner_memlet.src_subset = "0" - elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now clean the Memlets of that tree to use the new intermediate node. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): - consumer_edge = consumer_tree.edge - assert consumer_edge.data.data == inter_name - consumer_edge.data.data = new_inter_name - consumer_edge.data.replace(memlet_repl) - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. - # We will now delete the edges that brought the data. - for edge in state.in_edges_by_connector(map_entry_2, in_conn_name): - assert edge.src == inter_node - state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. - new_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert new_exit_memlet.data == inter_name - new_exit_memlet.subset = pre_exit_edge.data.dst_subset - new_exit_memlet.other_subset = ( - "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) - ) - - new_pre_exit_conn = map_exit_2.next_connector() - state.add_edge( - new_inter_node, - None, - map_exit_2, - "IN_" + new_pre_exit_conn, - new_exit_memlet, - ) - state.add_edge( - map_exit_2, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) - - map_exit_1.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py new file mode 100644 index 0000000000..6b7bd1b6d5 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -0,0 +1,1010 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""The GT4Py specific simplification pass.""" + +import collections +import copy +import uuid +from typing import Any, Final, Iterable, Optional, TypeAlias + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import ( + dataflow as dace_dataflow, + pass_pipeline as dace_ppl, + passes as dace_passes, +) + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} +"""Set of simplify passes `gt_simplify()` skips by default. + +The following passes are included: +- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a + symbol or vice versa and at a later point to invert this again. However, this + pass has some problems with this pattern so for the time being it is disabled. +- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. +""" + + +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[Iterable[str]] = None, +) -> Optional[dict[str, Any]]: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + This function runs the DaCe simplification pass, but the following passes are + replaced: + - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. + + Further, the function will run the following passes in addition to DaCe simplify: + - `GT4PyGlobalSelfCopyElimination`: Special copy pattern that in the context + of GT4Py based SDFG behaves as a no op. + + Furthermore, by default, or if `None` is passed for `skip` the passes listed in + `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied, defaults + to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. + + Note: + Currently DaCe does not provide a way to inject or exchange sub passes in + simplify. The custom inline pass is run at the beginning and the array + elimination at the end. The whole process is run inside a loop that ensures + that `gt_simplify()` results in a fix point. + """ + # Ensure that `skip` is a `set` + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) + + result: Optional[dict[str, Any]] = None + + at_least_one_xtrans_run = True + + while at_least_one_xtrans_run: + at_least_one_xtrans_run = False + + if "InlineSDFGs" not in skip: + inline_res = gt_inline_nested_sdfg( + sdfg=sdfg, + multistate=True, + permissive=False, + validate=validate, + validate_all=validate_all, + ) + if inline_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(inline_res) + + simplify_res = dace_passes.SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=(skip | {"InlineSDFGs"}), + ).apply_pass(sdfg, {}) + + if simplify_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(simplify_res) + + if "GT4PyGlobalSelfCopyElimination" not in skip: + self_copy_removal_result = sdfg.apply_transformations_repeated( + GT4PyGlobalSelfCopyElimination(), + validate=validate, + validate_all=validate_all, + ) + if self_copy_removal_result > 0: + at_least_one_xtrans_run = True + result = result or {} + result.setdefault("GT4PyGlobalSelfCopyElimination", 0) + result["GT4PyGlobalSelfCopyElimination"] += self_copy_removal_result + + return result + + +def gt_inline_nested_sdfg( + sdfg: dace.SDFG, + multistate: bool = True, + permissive: bool = False, + validate: bool = True, + validate_all: bool = False, +) -> Optional[dict[str, int]]: + """Perform inlining of nested SDFG into their parent SDFG. + + The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. + However, before the inline transformation is run the function will run some + cleaning passes that allows inlining nested SDFGs. + As a side effect, the function will split stages into more states. + + Args: + sdfg: The SDFG that should be processed, will be modified in place and returned. + multistate: Allow inlining of multistate nested SDFG, defaults to `True`. + permissive: Be less strict on the accepted SDFGs. + validate: Perform validation after the transformation has finished. + validate_all: Performs extensive validation. + """ + first_iteration = True + nb_preproccess_total = 0 + nb_inlines_total = 0 + while True: + nb_preproccess = sdfg.apply_transformations_repeated( + [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], + validate=False, + validate_all=validate_all, + ) + nb_preproccess_total += nb_preproccess + if (nb_preproccess == 0) and (not first_iteration): + break + + # Create and configure the inline pass + inline_sdfg = dace_passes.InlineSDFGs() + inline_sdfg.progress = False + inline_sdfg.permissive = permissive + inline_sdfg.multistate = multistate + + # Apply the inline pass + # The pass returns `None` no indicate "nothing was done" + nb_inlines = inline_sdfg.apply_pass(sdfg, {}) or 0 + nb_inlines_total += nb_inlines + + # Check result, if needed and test if we can stop + if validate_all or validate: + sdfg.validate() + if nb_inlines == 0: + break + first_iteration = False + + result: dict[str, int] = {} + if nb_inlines_total != 0: + result["InlineSDFGs"] = nb_inlines_total + if nb_preproccess_total != 0: + result["PruneSymbols|PruneConnectors"] = nb_preproccess_total + return result if result else None + + +def gt_substitute_compiletime_symbols( + sdfg: dace.SDFG, + repl: dict[str, Any], + validate: bool = False, + validate_all: bool = False, +) -> None: + """Substitutes symbols that are known at compile time with their value. + + Some symbols are known to have a constant value. This function will remove these + symbols from the SDFG and replace them with the value. + An example where this makes sense are strides that are known to be one. + + Args: + sdfg: The SDFG to process. + repl: Maps the name of the symbol to the value it should be replaced with. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + + # We will use the `replace` function of the top SDFG, however, lower levels + # are handled using ConstantPropagation. + sdfg.replace_dict(repl) + + const_prop = dace_passes.ConstantPropagation() + const_prop.recursive = True + const_prop.progress = False + + const_prop.apply_pass( + sdfg=sdfg, + initial_symbols=repl, + _=None, + ) + gt_simplify( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + dace.sdfg.propagation.propagate_memlets_sdfg(sdfg) + + +def gt_reduce_distributed_buffering( + sdfg: dace.SDFG, +) -> Optional[dict[dace.SDFG, dict[dace.SDFGState, set[str]]]]: + """Removes distributed write back buffers.""" + pipeline = dace_ppl.Pipeline([DistributedBufferRelocator()]) + all_result = {} + + for rsdfg in sdfg.all_sdfgs_recursive(): + ret = pipeline.apply_pass(sdfg, {}) + if ret is not None: + all_result[rsdfg] = ret + + return all_result + + +@dace_properties.make_properties +class GT4PyGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation): + """Remove global self copy. + + This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G` + is read from and written too at the same time, however, in between is `T` + used as a buffer. In the example above `G` is a global memory and `T` is a + temporary. This situation is generated by the lowering if the data node is + not needed (because the computation on it is only conditional). + + In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can + only have a point wise dependency of the output on the input. + This transformation will remove the write into `G`, i.e. we thus only have + `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed + if `T` is not used downstream. If it is used `T` will be maintained. + """ + + node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + read_g = self.node_read_g + write_g = self.node_write_g + tmp_node = self.node_tmp + g_desc = read_g.desc(sdfg) + tmp_desc = tmp_node.desc(sdfg) + + # NOTE: We do not check if `G` is read downstream. + if read_g.data != write_g.data: + return False + if g_desc.transient: + return False + if not tmp_desc.transient: + return False + if graph.in_degree(read_g) != 0: + return False + if graph.out_degree(read_g) != 1: + return False + if graph.degree(tmp_node) != 2: + return False + if graph.in_degree(write_g) != 1: + return False + if graph.out_degree(write_g) != 0: + return False + if graph.scope_dict()[read_g] is not None: + return False + + return True + + def _is_read_downstream( + self, + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + ) -> bool: + """Scans for reads to `data_to_look`. + + The function will go through states that are reachable from `start_state` + (including) and test if there is a read to the data container `data_to_look`. + It will return `True` the first time it finds such a node. + It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` + are ignored. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + + Todo: + Port this function to use DaCe pass pipeline. + """ + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + return gtx_transformations.util.is_accessed_downstream( + start_state=start_state, + sdfg=sdfg, + data_to_look=data_to_look, + nodes_to_ignore={read_g, write_g, tmp_node}, + ) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # We first check if `T`, the intermediate is not used downstream. In this + # case we can remove the read to `G` and `T` itself from the SDFG. + # We have to do this check before, because the matching is not fully stable. + is_tmp_used_downstream = self._is_read_downstream( + start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data + ) + + # The write to `G` can always be removed. + graph.remove_node(write_g) + + # Also remove the read to `G` and `T` from the SDFG if possible. + if not is_tmp_used_downstream: + graph.remove_node(read_g) + graph.remove_node(tmp_node) + # It could still be used in a parallel branch. + try: + sdfg.remove_data(tmp_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): + raise + + +AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode] +"""Describes an access node and the state in which it is located. +""" + + +@dace_properties.make_properties +class DistributedBufferRelocator(dace_transformation.Pass): + """Moves the final write back of the results to where it is needed. + + In certain cases, especially in case where we have `if` the result is computed + in each branch and then in the join state written back. Thus there is some + additional storage needed. + The transformation will look for the following situation: + - A transient data container, called `src_cont`, is written into another + container, called `dst_cont`, which is not transient. + - The access node of `src_cont` has an in degree of zero and an out degree of one. + - The access node of `dst_cont` has an in degree of of one and an + out degree of zero (this might be lifted). + - `src_cont` is not used afterwards. + - `dst_cont` is only used to implement the buffering. + + The function will relocate the writing of `dst_cont` to where `src_cont` is + written, which might be multiple locations. + It will also remove the writing back. + It is advised that after this transformation simplify is run again. + + Note: + Essentially this transformation removes the double buffering of `dst_cont`. + Because we ensure that that `dst_cont` is non transient this is okay, as our + rule guarantees this. + + Todo: + - Allow that `dst_cont` can also be transient. + - Allow that `dst_cont` does not need to be a sink node, this is most + likely most relevant if it is transient. + - Check if `dst_cont` is used between where we want to place it and + where it is currently used. + """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.StateReachability, + dace_transformation.passes.AccessSets, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFGState, set[str]]]: + reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[ + "StateReachability" + ][sdfg.cfg_id] + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]] = pipeline_results[ + "AccessSets" + ][sdfg.cfg_id] + result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set) + + to_relocate = self._find_candidates(sdfg, reachable, access_sets) + if len(to_relocate) == 0: + return None + self._relocate_write_backs(sdfg, to_relocate) + + for (wb_an, wb_state), _ in to_relocate: + result[wb_state].add(wb_an.data) + + return result + + def _relocate_write_backs( + self, + sdfg: dace.SDFG, + to_relocate: list[tuple[AccessLocation, list[AccessLocation]]], + ) -> None: + """Perform the actual relocation.""" + for (wb_an, wb_state), def_locations in to_relocate: + # Get the memlet that we have to replicate. + wb_edge = next(iter(wb_state.out_edges(wb_an))) + wb_memlet: dace.Memlet = wb_edge.data + final_dest_name: str = wb_edge.dst.data + + for def_an, def_state in def_locations: + def_state.add_edge( + def_an, + wb_edge.src_conn, + def_state.add_access(final_dest_name), + wb_edge.dst_conn, + copy.deepcopy(wb_memlet), + ) + + # Now remove the old node and if the old target become isolated + # remove that as well. + old_dst = wb_edge.dst + wb_state.remove_node(wb_an) + if wb_state.degree(old_dst) == 0: + wb_state.remove_node(old_dst) + + def _find_candidates( + self, + sdfg: dace.SDFG, + reachable: dict[dace.SDFGState, set[dace.SDFGState]], + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]], + ) -> list[tuple[AccessLocation, list[AccessLocation]]]: + """Determines all temporaries that have to be relocated. + + Returns: + A list of tuples. The first element element of the tuple is an + `AccessLocation` that describes where the temporary is read. + The second element is a list of `AccessLocation`s that describes + where the temporary is defined. + """ + # All nodes that are used as distributed buffers. + candidate_src_cont: list[AccessLocation] = [] + + # Which `src_cont` access node is written back to which global memory. + src_cont_to_global: dict[dace_nodes.AccessNode, str] = {} + + for state in sdfg.states(): + # These are the possible targets we want to write into. + candidate_dst_nodes: set[dace_nodes.AccessNode] = { + node + for node in state.sink_nodes() + if ( + isinstance(node, dace_nodes.AccessNode) + and state.in_degree(node) == 1 + and (not node.desc(sdfg).transient) + ) + } + if len(candidate_dst_nodes) == 0: + continue + + for src_cont in state.source_nodes(): + if not isinstance(src_cont, dace_nodes.AccessNode): + continue + if not src_cont.desc(sdfg).transient: + continue + if state.out_degree(src_cont) != 1: + continue + dst_candidate: dace_nodes.AccessNode = next( + iter(edge.dst for edge in state.out_edges(src_cont)) + ) + if dst_candidate not in candidate_dst_nodes: + continue + candidate_src_cont.append((src_cont, state)) + src_cont_to_global[src_cont] = dst_candidate.data + + if len(candidate_src_cont) == 0: + return [] + + # Now we have to find the places where the temporary sources are defined. + # I.e. This is also the location where the original value is defined. + result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: + return { + src_state + for src_state in sdfg.states() + if dst_state in reachable[src_state] and dst_state is not src_state + } + + for src_cont in candidate_src_cont: + def_locations: list[AccessLocation] = [] + for upstream_state in find_upstream_states(src_cont[1]): + if src_cont[0].data in access_sets[upstream_state][1]: + def_locations.extend( + (data_node, upstream_state) + for data_node in upstream_state.data_nodes() + if data_node.data == src_cont[0].data + ) + if len(def_locations) != 0: + result_candidates.append((src_cont, def_locations)) + + # This transformation removes `src_cont` by writing its content directly + # to `dst_cont`, at the point where it is defined. + # For this transformation to be valid the following conditions have to be met: + # - Between the definition of `src_cont` and the write back to `dst_cont`, + # `dst_cont` can not be accessed. + # - Between the definitions of `src_cont` and the point where it is written + # back, `src_cont` can only be accessed in the range that is written back. + # - After the write back point, `src_cont` shall not be accessed. This + # restriction could be lifted. + # + # To keep the implementation simple, we use the conditions: + # - `src_cont` is only accessed were it is defined and at the write back + # point. + # - Between the definitions of `src_cont` and the write back point, + # `dst_cont` is not used. + + result: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + for wb_localation, def_locations in result_candidates: + for def_node, def_state in def_locations: + # Test if `src_cont` is only accessed where it is defined and + # where it is written back. + if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=wb_localation[0].data, + nodes_to_ignore={def_node, wb_localation[0]}, + ): + break + # check if the global data is not used between the definition of + # `dst_cont` and where its written back. We allow one exception, + # if the global data is used in the state the distributed temporary + # is defined is used only for reading then it is ignored. This is + # allowed because of rule 3 of ADR0018. + glob_nodes_in_def_state = { + dnode + for dnode in def_state.data_nodes() + if dnode.data == src_cont_to_global[wb_localation[0]] + } + if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state): + break + if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=src_cont_to_global[wb_localation[0]], + nodes_to_ignore=glob_nodes_in_def_state, + states_to_ignore={wb_localation[1]}, + ): + break + else: + result.append((wb_localation, def_locations)) + + return result + + +@dace_properties.make_properties +class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation): + """Moves a Tasklet, with no input into a map. + + Tasklets without inputs, are mostly used to generate constants. + However, if they are outside a Map, then this constant value is an + argument to the kernel, and can not be used by the compiler. + + This transformation moves such Tasklets into a Map scope. + """ + + tasklet = dace_transformation.PatternNode(dace_nodes.Tasklet) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.tasklet, cls.access_node, cls.map_entry)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Data = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + if graph.in_degree(tasklet) != 0: + return False + if graph.out_degree(tasklet) != 1: + return False + if tasklet.has_side_effects(sdfg): + return False + if tasklet.code_init.as_string: + return False + if tasklet.code_exit.as_string: + return False + if tasklet.code_global.as_string: + return False + if tasklet.state_fields: + return False + if not isinstance(access_desc, dace_data.Scalar): + return False + if not access_desc.transient: + return False + if not any( + edge.dst_conn and edge.dst_conn.startswith("IN_") + for edge in graph.out_edges(access_node) + if edge.dst is map_entry + ): + return False + # NOTE: We allow that the access node is used in multiple places. + + return True + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Scalar = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + # Find _a_ connection that leads from the access node to the map. + edge_to_map = next( + iter( + edge + for edge in graph.out_edges(access_node) + if edge.dst is map_entry and edge.dst_conn.startswith("IN_") + ) + ) + connector_name: str = edge_to_map.dst_conn[3:] + + # This is the tasklet that we will put inside the map, note we have to do it + # this way to avoid some name clash stuff. + inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet( + name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}", + outputs=tasklet.out_connectors.keys(), + inputs=set(), + code=tasklet.code, + language=tasklet.language, + debuginfo=tasklet.debuginfo, + ) + inner_desc: dace_data.Scalar = access_desc.clone() + inner_data_name: str = sdfg.add_datadesc(access_node.data, inner_desc, find_new_name=True) + inner_an: dace_nodes.AccessNode = graph.add_access(inner_data_name) + + # Connect the tasklet with the map entry and the access node. + graph.add_nedge(map_entry, inner_tasklet, dace.Memlet()) + graph.add_edge( + inner_tasklet, + next(iter(inner_tasklet.out_connectors.keys())), + inner_an, + None, + dace.Memlet(f"{inner_data_name}[0]"), + ) + + # Now we will reroute the edges went through the inner map, through the + # inner access node instead. + for old_inner_edge in list( + graph.out_edges_by_connector(map_entry, "OUT_" + connector_name) + ): + # We now modify the downstream data. This is because we no longer refer + # to the data outside but the one inside. + self._modify_downstream_memlets( + state=graph, + edge=old_inner_edge, + old_data=access_node.data, + new_data=inner_data_name, + ) + + # After we have changed the properties of the MemletTree of `edge` + # we will now reroute it, such that the inner access node is used. + graph.add_edge( + inner_an, + None, + old_inner_edge.dst, + old_inner_edge.dst_conn, + old_inner_edge.data, + ) + graph.remove_edge(old_inner_edge) + map_entry.remove_in_connector("IN_" + connector_name) + map_entry.remove_out_connector("OUT_" + connector_name) + + # Now we can remove the map connection between the outer/old access + # node and the map. + graph.remove_edge(edge_to_map) + + # The data is no longer referenced in this state, so we can potentially + # remove + if graph.out_degree(access_node) == 0: + if not gtx_transformations.util.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + data_to_look=access_node.data, + nodes_to_ignore={access_node}, + ): + graph.remove_nodes_from([tasklet, access_node]) + # Needed if data is accessed in a parallel branch. + try: + sdfg.remove_data(access_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {access_node.data}:"): + raise + + def _modify_downstream_memlets( + self, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge, + old_data: str, + new_data: str, + ) -> None: + """Replaces the data along on the tree defined by `edge`. + + The function will traverse the MemletTree defined by `edge`. + Any Memlet that refers to `old_data` will be replaced with + `new_data`. + + Args: + state: The sate in which we operate. + edge: The edge defining the MemletTree. + old_data: The name of the data that should be replaced. + new_data: The name of the new data the Memlet should refer to. + """ + mtree: dace.memlet.MemletTree = state.memlet_tree(edge) + for tedge in mtree.traverse_children(True): + # Because we only change the name of the data, we do not change the + # direction of the Memlet, so `{src, dst}_subset` will remain the same. + if tedge.edge.data.data == old_data: + tedge.edge.data.data = new_data + + +@dace_properties.make_properties +class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation): + """Allows to remove unneeded buffering at map output. + + The transformation matches the case `MapExit -> (T) -> (G)`, where `T` is an + AccessNode referring to a transient and `G` an AccessNode that refers to non + transient memory. + If the following conditions are met then `T` is removed. + - `T` is not used to filter computations, i.e. what is written into `G` + is covered by what is written into `T`. + - `T` is not used anywhere else. + - `G` is not also an input to the map, except there is only a pointwise + dependency in `G`, see the note below. + - Everything needs to be at top scope. + + Notes: + - Rule 3 of ADR18 should guarantee that any valid GT4Py program meets the + point wise dependency in `G`, for that reason it is possible to disable + this test by specifying `assume_pointwise`. + + Todo: + - Implement a real pointwise test. + """ + + map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + tmp_ac = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + glob_ac = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_pointwise = dace_properties.Property( + dtype=bool, + default=False, + desc="Dimensions that should become the leading dimension.", + ) + + def __init__( + self, + assume_pointwise: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if assume_pointwise is not None: + self.assume_pointwise = assume_pointwise + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)] + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return {dace_transformation.passes.ConsolidateEdges} + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + glob_ac: dace_nodes.AccessNode = self.glob_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + glob_desc: dace_data.Data = glob_ac.desc(sdfg) + + if not tmp_desc.transient: + return False + if glob_desc.transient: + return False + if graph.in_degree(tmp_ac) != 1: + return False + if any(gtx_transformations.util.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): + return False + if len(glob_desc.shape) != len(tmp_desc.shape): + return False + + # Test if we are on the top scope (it is likely). + if graph.scope_dict()[glob_ac] is not None: + return False + + # Now perform if we are point wise + if not self._perform_pointwise_test(graph, sdfg): + return False + + # Test if `tmp` is only anywhere else, this is important for removing it. + if graph.out_degree(tmp_ac) != 1: + return False + if gtx_transformations.util.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + data_to_look=tmp_ac.data, + nodes_to_ignore={tmp_ac}, + ): + return False + + # Now we ensure that `tmp` is not used to filter out some computations. + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + tmp_in_subset = map_to_tmp_edge.data.get_dst_subset(map_to_tmp_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + if tmp_in_subset is None: + tmp_in_subset = dace_subsets.Range.from_array(tmp_desc) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + if glob_in_subset is None: + return False + + # TODO(phimuell): Do we need simplify in the check. + # TODO(phimuell): Restrict this to having the same size. + if tmp_out_subset != tmp_in_subset: + return False + return True + + def _perform_pointwise_test( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Test if `G` is only point wise accessed. + + This function will also consider the `assume_pointwise` property. + """ + map_exit: dace_nodes.MapExit = self.map_exit + map_entry: dace_nodes.MapEntry = state.entry_node(map_exit) + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data: str = glob_ac.data + + # First we check if `G` is also an input to this map. + conflicting_inputs: set[dace_nodes.AccessNode] = set() + for in_edge in state.in_edges(map_entry): + if not isinstance(in_edge.src, dace_nodes.AccessNode): + continue + + # Find the source of this data, if it is a view we trace it to + # its origin. + src_node: dace_nodes.AccessNode = gtx_transformations.util.track_view( + in_edge.src, state, sdfg + ) + + # Test if there is a conflict; We do not store the source but the + # actual node that is adjacent. + if src_node.data == glob_data: + conflicting_inputs.add(in_edge.src) + + # If there are no conflicting inputs, then we are point wise. + # This is an implementation detail that make life simpler. + if len(conflicting_inputs) == 0: + return True + + # If we can assume pointwise computations, then we do not have to do + # anything. + if self.assume_pointwise: + return True + + # Currently the only test that we do is, if we have a view, then we + # are not point wise. + # TODO(phimuell): Improve/implement this. + return any(gtx_transformations.util.is_view(node, sdfg) for node in conflicting_inputs) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + # Removal + # Propagation ofthe shift. + map_exit: dace_nodes.MapExit = self.map_exit + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + tmp_data = tmp_ac.data + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data = glob_ac.data + + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + assert glob_in_subset is not None + + # We now remove the `tmp` node, and create a new connection between + # the global node and the map exit. + new_map_to_glob_edge = graph.add_edge( + map_exit, + map_to_tmp_edge.src_conn, + glob_ac, + tmp_to_glob_edge.dst_conn, + dace.Memlet( + data=glob_ac.data, + subset=copy.deepcopy(glob_in_subset), + ), + ) + graph.remove_edge(map_to_tmp_edge) + graph.remove_edge(tmp_to_glob_edge) + graph.remove_node(tmp_ac) + + # We can not unconditionally remove the data `tmp` refers to, because + # it could be that in a parallel branch the `tmp` is also defined. + try: + sdfg.remove_data(tmp_ac.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_ac.data}:"): + raise + + # Now we must modify the memlets inside the map scope, because + # they now write into `G` instead of `tmp`, which has a different + # offset. + # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. + correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) + mtree = graph.memlet_tree(new_map_to_glob_edge) + for tree in mtree.traverse_children(include_self=False): + curr_edge = tree.edge + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) + if curr_edge.data.data == tmp_data: + curr_edge.data.data = glob_data + if curr_dst_subset is not None: + curr_dst_subset.offset(correcting_offset, negative=False) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py new file mode 100644 index 0000000000..4e254f2880 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -0,0 +1,99 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dace +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_change_transient_strides( + sdfg: dace.SDFG, + gpu: bool, +) -> dace.SDFG: + """Modifies the strides of transients. + + The function will analyse the access patterns and set the strides of + transients in the optimal way. + The function should run after all maps have been created. + + Args: + sdfg: The SDFG to process. + gpu: If the SDFG is supposed to run on the GPU. + + Note: + Currently the function will not scan the access pattern. Instead it will + either use FORTRAN order for GPU or C order (which is assumed to be the + default, so it is a no ops). + + Todo: + - Implement the estimation correctly. + - Handle the case of nested SDFGs correctly; on the outside a transient, + but on the inside a non transient. + """ + # TODO(phimeull): Implement this function correctly. + + # We assume that by default we have C order which is already correct, + # so in this case we have a no ops + if not gpu: + return sdfg + + for nsdfg in sdfg.all_sdfgs_recursive(): + # TODO(phimuell): Handle the case when transient goes into nested SDFG + # on the inside it is a non transient, so it is ignored. + _gt_change_transient_strides_non_recursive_impl(nsdfg) + + +def _gt_change_transient_strides_non_recursive_impl( + sdfg: dace.SDFG, +) -> None: + """Essentially this function just changes the stride to FORTRAN order.""" + for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + desc: dace_data.Array = sdfg.arrays[top_level_transient] + ndim = len(desc.shape) + if ndim <= 1: + continue + # We assume that everything is in C order initially, to get FORTRAN order + # we simply have to reverse the order. + new_stride_order = list(range(ndim)) + desc.set_strides_from_layout(*new_stride_order) + + +def _find_toplevel_transients( + sdfg: dace.SDFG, + only_arrays: bool = False, +) -> set[str]: + """Find all top level transients in the SDFG. + + The function will scan the SDFG, ignoring nested one, and return the + name of all transients that have an access node at the top level. + However, it will ignore access nodes that refers to registers. + """ + top_level_transients: set[str] = set() + for state in sdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + if data in top_level_transients: + top_level_transients.remove(data) + continue + elif data in top_level_transients: + continue + elif gtx_transformations.util.is_view(dnode, sdfg): + continue + desc: dace_data.Data = dnode.desc(sdfg) + + if not desc.transient: + continue + elif only_arrays and not isinstance(desc, dace_data.Array): + continue + top_level_transients.add(data) + return top_level_transients diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 29bae7bbe0..29c099eecf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -8,153 +8,220 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Iterable, Union +from typing import Any, Container, Optional, Union import dace -from dace.sdfg import graph as dace_graph, nodes as dace_nodes +from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], -) -> bool: - """Tests if `sdfg` is a NestedSDFG.""" - if isinstance(sdfg, dace.SDFGState): - sdfg = sdfg.parent - if isinstance(sdfg, dace_nodes.NestedSDFG): - return True - elif isinstance(sdfg, dace.SDFG): - return sdfg.parent_nsdfg_node is not None - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") - - -def all_nodes_between( - graph: dace.SDFG | dace.SDFGState, - begin: dace_nodes.Node, - end: dace_nodes.Node, - reverse: bool = False, -) -> set[dace_nodes.Node] | None: - """Find all nodes that are reachable from `begin` but bound by `end`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + +def gt_make_transients_persistent( + sdfg: dace.SDFG, + device: dace.DeviceType, +) -> dict[int, set[str]]: + """ + Changes the lifetime of certain transients to `Persistent`. + + A persistent lifetime means that the transient is allocated only the very first + time the SDFG is executed and only deallocated if the underlying `CompiledSDFG` + object goes out of scope. The main advantage is, that memory must not be + allocated every time the SDFG is run. The downside is that the SDFG can not be + called by different threads. Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. + sdfg: The SDFG to process. + device: The device type. + + Returns: + A `dict` mapping SDFG IDs to a set of transient arrays that + were made persistent. + + Note: + This function is based on a similar function in DaCe. However, the DaCe + function does, for unknown reasons, also reset the `wcr_nonatomic` property, + but only for GPU. """ + result: dict[int, set[str]] = {} + for nsdfg in sdfg.all_sdfgs_recursive(): + fsyms: set[str] = nsdfg.free_symbols + modify_lifetime: set[str] = set() + not_modify_lifetime: set[str] = set() + + for state in nsdfg.states(): + for dnode in state.data_nodes(): + if dnode.data in not_modify_lifetime: + continue - def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: - return ( - (edge.src for edge in graph.in_edges(node)) - if reverse - else (edge.dst for edge in graph.out_edges(node)) - ) + if dnode.data in nsdfg.constants_prop: + not_modify_lifetime.add(dnode.data) + continue - if reverse: - begin, end = end, begin + desc = dnode.desc(nsdfg) + if not desc.transient or type(desc) not in {dace.data.Array, dace.data.Scalar}: + not_modify_lifetime.add(dnode.data) + continue + if desc.storage == dace.StorageType.Register: + not_modify_lifetime.add(dnode.data) + continue - to_visit: list[dace_nodes.Node] = [begin] - seen: set[dace_nodes.Node] = set() + if desc.lifetime == dace.AllocationLifetime.External: + not_modify_lifetime.add(dnode.data) + continue - while len(to_visit) > 0: - node: dace_nodes.Node = to_visit.pop() - if node != end and node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) + try: + # The symbols describing the total size must be a subset of the + # free symbols of the SDFG (symbols passed as argument). + # NOTE: This ignores the renaming of symbols through the + # `symbol_mapping` property of nested SDFGs. + if not set(map(str, desc.total_size.free_symbols)).issubset(fsyms): + not_modify_lifetime.add(dnode.data) + continue + except AttributeError: # total_size is an integer / has no free symbols + pass - # If `end` was not found we have to return `None` to indicate this. - if end not in seen: - return None + # Make it persistent. + modify_lifetime.add(dnode.data) - # `begin` and `end` are not included in the output set. - return seen - {begin, end} + # Now setting the lifetime. + result[nsdfg.cfg_id] = modify_lifetime - not_modify_lifetime + for aname in result[nsdfg.cfg_id]: + nsdfg.arrays[aname].lifetime = dace.AllocationLifetime.Persistent + return result -def find_downstream_consumers( - state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, - reverse: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Find all downstream connectors of `begin`. - - A consumer, in for this function, is any node that is neither an entry nor - an exit node. The function returns a set of pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. + +def gt_find_constant_arguments( + call_args: dict[str, Any], + include: Optional[Container[str]] = None, +) -> dict[str, Any]: + """Scans the calling arguments for compile time constants. + + The output of this function can be used as input to + `gt_substitute_compiletime_symbols()`, which then removes these symbols. + + By specifying `include` it is possible to force the function to include + additional arguments, that would not be matched otherwise. Importantly, + their value is not checked. Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. + call_args: The full list of arguments that will be passed to the SDFG. + include: List of arguments that should be included. """ - if isinstance(begin, dace_graph.MultiConnectorEdge): - to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] - else: - to_visit = state.in_edges(begin) if reverse else state.out_edges(begin) + if include is None: + include = set() + ret_value: dict[str, Any] = {} - seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + for name, value in call_args.items(): + if name in include or (dace_utils.is_field_symbol(name) and value == 1): + ret_value[name] = value - while len(to_visit) > 0: - curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): - if not reverse: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if isinstance(next_node, dace_nodes.MapEntry) and ( - not curr_edge.dst_conn.startswith("IN_") - ): - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - else: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - to_visit.extend(new_edges) + return ret_value - elif isinstance(next_node, dace_nodes.Tasklet) or not only_tasklets: - # We have found a consumer. - found.add((next_node, curr_edge)) - return found +def is_accessed_downstream( + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + nodes_to_ignore: Optional[set[dace_nodes.AccessNode]] = None, + states_to_ignore: Optional[set[dace.SDFGState]] = None, +) -> bool: + """Scans for accesses to the data container `data_to_look`. + The function will go through states that are reachable from `start_state` + (included) and test if there is an AccessNode that refers to `data_to_look`. + It will return `True` the first time it finds such a node. -def find_upstream_producers( + The function will ignore all nodes that are listed in `nodes_to_ignore`. + Furthermore, states listed in `states_to_ignore` will be ignored, i.e. + handled as they did not exist. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + nodes_to_ignore: Ignore these nodes. + states_to_ignore: Ignore these states. + """ + seen_states: set[dace.SDFGState] = set() + to_visit: list[dace.SDFGState] = [start_state] + ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set() + ign_states: set[dace.SDFGState] = states_to_ignore or set() + + while len(to_visit) > 0: + state = to_visit.pop() + seen_states.add(state) + for dnode in state.data_nodes(): + if dnode.data != data_to_look: + continue + if dnode in ign_dnodes: + continue + if state.out_degree(dnode) != 0: + return True # There is a read operation + + # Look for new states, also scan the interstate edges. + for out_edge in sdfg.out_edges(state): + if out_edge.dst in ign_states: + continue + if data_to_look in out_edge.data.read_symbols(): + return True + if out_edge.dst in seen_states: + continue + to_visit.append(out_edge.dst) + + return False + + +def is_view( + node: Union[dace_nodes.AccessNode, dace_data.Data], + sdfg: dace.SDFG, +) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: dace_data.Data = node.desc(sdfg) if isinstance(node, dace_nodes.AccessNode) else node + return isinstance(node_desc, dace_data.View) + + +def track_view( + view: dace_nodes.AccessNode, state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) + sdfg: dace.SDFG, +) -> dace_nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 80b8f4f39b..d7413f32d7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -169,14 +169,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): BACKEND_SKIP_TEST_MATRIX = { EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [ - (ALL, SKIP, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): Enable once the optimization pipeline is merged - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [ - (ALL, SKIP, UNSUPPORTED_MESSAGE) - ], # TODO(edopao): Enable once the optimization pipeline is merged. + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py index e85ef6ad1f..0eb0bf39c2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py @@ -11,7 +11,7 @@ import pytest -@pytest.fixture() +@pytest.fixture(autouse=True) def set_dace_settings() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. @@ -24,6 +24,6 @@ def set_dace_settings() -> Generator[None, None, None]: import dace with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=False) + dace.Config.set("optimizer", "match_exception", value=True) dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py new file mode 100644 index 0000000000..04a4f098ef --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -0,0 +1,142 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +def test_constant_substitution(): + sdfg, nsdfg = _make_sdfg() + + # Ensure that `One` is present. + assert len(sdfg.symbols) == 2 + assert len(nsdfg.sdfg.symbols) == 2 + assert len(nsdfg.symbol_mapping) == 2 + assert "One" in sdfg.symbols + assert "One" in nsdfg.sdfg.symbols + assert "One" in nsdfg.symbol_mapping + assert "One" == str(nsdfg.symbol_mapping["One"]) + assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" in sdfg.used_symbols(True) + + # Now replace `One` with 1 + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1}) + + assert len(sdfg.symbols) == 1 + assert len(nsdfg.sdfg.symbols) == 1 + assert len(nsdfg.symbol_mapping) == 1 + assert "One" not in sdfg.symbols + assert "One" not in nsdfg.sdfg.symbols + assert "One" not in nsdfg.symbol_mapping + assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) + assert all( + desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + ) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" not in sdfg.used_symbols(True) + + +def _make_nested_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("nested") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABC": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + state = sdfg.add_state(is_start_block=True) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("B[__i0, __i1]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: + sdfg = dace.SDFG("outer_sdfg") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABCD": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + sdfg.arrays["C"].transient = True + + first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) + nested_sdfg: dace.SDFG = _make_nested_sdfg() + nsdfg = first_state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A", "B"}, + outputs={"C"}, + symbol_mapping={"One": "One", "N": "N"}, + ) + first_state.add_edge( + first_state.add_access("A"), + None, + nsdfg, + "A", + dace.Memlet("A[0:N, 0:N]"), + ) + first_state.add_edge( + first_state.add_access("B"), + None, + nsdfg, + "B", + dace.Memlet("B[0:N, 0:N]"), + ) + first_state.add_edge( + nsdfg, + "C", + first_state.add_access("C"), + None, + dace.Memlet("C[0:N, 0:N]"), + ) + + second_state: dace.SDFGState = sdfg.add_state_after(first_state) + second_state.add_mapped_tasklet( + "outer_computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("C[__i0, __i1]"), + }, + code="__out = __in0 * __in1", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg, nsdfg diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py new file mode 100644 index 0000000000..3d9201c603 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -0,0 +1,239 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _create_sdfg_double_read_part_1( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_1", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 1.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("A[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("A"), None, dace.Memlet("A[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read_part_2( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_2", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 3.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("B[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("B"), None, dace.Memlet("B[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read( + version: int, +) -> tuple[dace.SDFG]: + sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in = state.add_access("A") + me, mx = state.add_map("map", ndrange={"__i0": "0:10"}) + + if version == 0: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 0) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 1) + elif version == 1: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 1) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 0) + else: + raise ValueError(f"Does not know version {version}") + sdfg.validate() + return sdfg + + +def test_local_double_buffering_double_read_sdfg(): + sdfg0 = _create_sdfg_double_read(0) + sdfg1 = _create_sdfg_double_read(1) + args0 = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + args1 = copy.deepcopy(args0) + + count0 = gtx_transformations.gt_create_local_double_buffering(sdfg0) + assert count0 == 1 + + count1 = gtx_transformations.gt_create_local_double_buffering(sdfg1) + assert count1 == 1 + + sdfg0(**args0) + sdfg1(**args1) + for name in args0: + assert np.allclose(args0[name], args1[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_connection(): + """There is no direct connection between read and write.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in, B, A_out = (state.add_access(name) for name in "ABA") + + comp_tskl, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + input_nodes={A_in}, + output_nodes={B}, + external_edges=True, + ) + + fill_tasklet = state.add_tasklet( + name="fill_tasklet", + inputs=set(), + code="__out = 2.", + outputs={"__out"}, + ) + state.add_nedge(me, fill_tasklet, dace.Memlet()) + state.add_edge(fill_tasklet, "__out", mx, "IN_1", dace.Memlet("A[__i0]")) + state.add_edge(mx, "OUT_1", A_out, None, dace.Memlet("A[0:10]")) + mx.add_in_connector("IN_1") + mx.add_out_connector("OUT_1") + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 1 + + # Ensure that a second application of the transformation does not run again. + count_again = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count_again == 0 + + # Find the newly created access node. + comp_tasklet_producers = [in_edge.src for in_edge in state.in_edges(comp_tskl)] + assert len(comp_tasklet_producers) == 1 + new_double_buffer = comp_tasklet_producers[0] + assert isinstance(new_double_buffer, dace_nodes.AccessNode) + assert not any(new_double_buffer.data == name for name in "AB") + assert isinstance(new_double_buffer.desc(sdfg), dace_data.Scalar) + assert new_double_buffer.desc(sdfg).transient + + # The newly created access node, must have an empty Memlet to the fill tasklet. + read_dependencies = [ + out_edge.dst for out_edge in state.out_edges(new_double_buffer) if out_edge.data.is_empty() + ] + assert len(read_dependencies) == 1 + assert read_dependencies[0] is fill_tasklet + + res = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + ref = {"A": np.full_like(res["A"], 2.0), "B": res["A"] + 10.0} + sdfg(**res) + for name in res: + assert np.allclose(res[name], ref[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_apply(): + """Here it does not apply, because are all distinct.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + external_edges=True, + ) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 + + +def test_local_double_buffering_already_buffered(): + """It is already buffered.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_array( + "A", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + tsklt, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("A[__i0]")}, + external_edges=True, + ) + + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + me_to_tskl_edge = next(iter(state.out_edges(me))) + + state.add_edge(me, me_to_tskl_edge.src_conn, tmp, None, dace.Memlet("A[__i0]")) + state.add_edge(tmp, None, tsklt, "__in1", dace.Memlet("tmp[0]")) + state.remove_edge(me_to_tskl_edge) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py new file mode 100644 index 0000000000..1543a048ad --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +# from . import util + + +# dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +import dace + + +def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG("NAME") # util.unique_name("distributed_buffer_sdfg")) + + for name in ["a", "b", "tmp"]: + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + sdfg.arrays["tmp"].transient = True + sdfg.arrays["b"].shape = (100, 100) + + state1: dace.SDFGState = sdfg.add_state(is_start_block=True) + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i1": "0:10", "__i2": "0:10"}, + inputs={"__in": dace.Memlet("a[__i1, __i2]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("tmp[__i1, __i2]")}, + external_edges=True, + ) + + state2 = sdfg.add_state_after(state1) + state2_tskl = state2.add_tasklet( + name="empty_blocker_tasklet", + inputs={}, + code="pass", + outputs={"__out"}, + side_effects=True, + ) + state2.add_edge( + state2_tskl, + "__out", + state2.add_access("a"), + None, + dace.Memlet("a[0, 0]"), + ) + + state3 = sdfg.add_state_after(state2) + state3.add_edge( + state3.add_access("tmp"), + None, + state3.add_access("b"), + None, + dace.Memlet("tmp[0:10, 0:10] -> [11:21, 22:32]"), + ) + sdfg.validate() + assert sdfg.number_of_nodes() == 3 + + return sdfg, state1 + + +def test_distributed_buffer_remover(): + sdfg, state1 = _mk_distributed_buffer_sdfg() + assert state1.number_of_nodes() == 5 + assert not any(dnode.data == "b" for dnode in state1.data_nodes()) + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res is not None + + # Because the final state has now become empty + assert sdfg.number_of_nodes() == 3 + assert state1.number_of_nodes() == 6 + assert any(dnode.data == "b" for dnode in state1.data_nodes()) + assert any(dnode.data == "tmp" for dnode in state1.data_nodes()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py new file mode 100644 index 0000000000..4ca44d43eb --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py @@ -0,0 +1,148 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + """Generates an SDFG that contains the self copying pattern.""" + sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "GT": + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["G"].transient = False + g_read, tmp_node, g_write = (state.add_access(name) for name in "GTG") + + state.add_nedge(g_read, tmp_node, dace.Memlet("G[0:10, 0:10]")) + state.add_nedge(tmp_node, g_write, dace.Memlet("G[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state + + +def test_global_self_copy_elimination_only_pattern(): + """Contains only the pattern -> Total elimination.""" + sdfg, state = _make_self_copy_sdfg() + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 3 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert state.number_of_edges() == 2 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 1 + assert ( + state.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + + +def test_global_self_copy_elimination_g_downstream(): + """`G` is read downstream. + + Since we ignore reads to `G` downstream, this will not influence the + transformation. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. + state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("G[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert ( + state1.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 + + +def test_global_self_copy_elimination_tmp_downstream(): + """`T` is read downstream. + + Because `T` is read downstream, the read to `G` will be retained, but the write + will be removed. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. + state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("T[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert state1.number_of_nodes() == 2 + assert util.count_nodes(state1, dace_nodes.AccessNode) == 2 + assert all(state1.degree(node) == 1 for node in state1.nodes()) + assert next(iter(state1.source_nodes())).data == "G" + assert next(iter(state1.sink_nodes())).data == "T" + + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 30266d71d1..89f067e5a9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -24,11 +24,12 @@ def _get_trivial_gpu_promotable( tasklet_code: str, + trivial_map_range: str = "0", ) -> tuple[dace.SDFG, dace_nodes.MapEntry, dace_nodes.MapEntry]: - """Returns an SDFG that is suitable to test the `TrivialGPUMapPromoter` promoter. + """Returns an SDFG that is suitable to test the `TrivialGPUMapElimination` promoter. The first map is a trivial map (`Map[__trival_gpu_it=0]`) containing a Tasklet, - that does not have an output, but writes a scalar value into `tmp` (output + that does not have an input, but writes a scalar value into `tmp` (output connector `__out`), the body of this Tasklet can be controlled through the `tasklet_code` argument. The second map (`Map[__i0=0:N]`) contains a Tasklet that computes the sum of its @@ -41,6 +42,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. + trivial_map_range: Range of the trivial map, defaults to `"0"`. """ sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -57,11 +59,11 @@ def _get_trivial_gpu_promotable( _, trivial_map_entry, _ = state.add_mapped_tasklet( "trivail_top_tasklet", - map_ranges={"__trivial_gpu_it": "0"}, + map_ranges={"__trivial_gpu_it": trivial_map_range}, inputs={}, code=tasklet_code, outputs={"__out": dace.Memlet("tmp[0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, schedule=schedule, ) @@ -74,15 +76,15 @@ def _get_trivial_gpu_promotable( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("b[__i0]")}, - input_nodes={"a": a, "tmp": tmp}, - output_nodes={"b": b}, + input_nodes={a, tmp}, + output_nodes={b}, external_edges=True, schedule=schedule, ) return sdfg, trivial_map_entry, second_map_entry -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_1(): """Tests if the GPU map promoter works. By using a body such as `__out = 3.0`, the transformation will apply. @@ -92,15 +94,15 @@ def test_trivial_gpu_map_promoter(): org_second_map_ranges = copy.deepcopy(second_map_entry.map.range) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) assert ( nb_runs == 1 - ), f"Expected that 'TrivialGPUMapPromoter' applies once but it applied {nb_runs}." + ), f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}." trivial_map_params = trivial_map_entry.map.params - trivial_map_ranges = trivial_map_ranges.map.range + trivial_map_ranges = trivial_map_entry.map.range second_map_params = second_map_entry.map.params second_map_ranges = second_map_entry.map.range @@ -119,32 +121,82 @@ def test_trivial_gpu_map_promoter(): assert sdfg.is_valid() -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_2(): """Test if the GPU promoter does not fuse a special trivial map. By using a body such as `__out = __trivial_gpu_it` inside the - Tasklet's body, the map parameter is now used, and thus can not be fused. + Tasklet's body, the map parameter must now be replaced inside + the Tasklet's body. """ sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable( - "__out = __trivial_gpu_it" + tasklet_code="__out = __trivial_gpu_it", + trivial_map_range="2", + ) + state: dace.SDFGStae = sdfg.nodes()[0] + trivial_tasklet: dace_nodes.Tasklet = next( + iter( + out_edge.dst + for out_edge in state.out_edges(trivial_map_entry) + if isinstance(out_edge.dst, dace_nodes.Tasklet) + ) ) - org_trivial_map_params = list(trivial_map_entry.map.params) - org_second_map_params = list(second_map_entry.map.params) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) - assert ( - nb_runs == 0 - ), f"Expected that 'TrivialGPUMapPromoter' does not apply but it applied {nb_runs}." - trivial_map_params = trivial_map_entry.map.params - second_map_params = second_map_entry.map.params - assert ( - trivial_map_params == org_trivial_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert ( - second_map_params == org_second_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert sdfg.is_valid() + assert nb_runs == 1 + + expected_trivial_code = "__out = 2" + assert trivial_tasklet.code == expected_trivial_code + + +def test_set_gpu_properties(): + """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" + sdfg = dace.SDFG("gpu_properties_test") + state = sdfg.add_state(is_start_block=True) + + map_entries: dict[int, dace_nodes.MapEntry] = {} + for dim in [1, 2, 3]: + shape = (10,) * dim + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + map_entries[dim] = me + + sdfg.apply_gpu_transformations() + sdfg.validate() + + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size=(10, "11", 12), + launch_factor_2d=2, + block_size_2d=(2, 2, 2), + launch_bounds_3d=200, + ) + + map1, map2, map3 = (map_entries[d].map for d in [1, 2, 3]) + + assert len(map1.params) == 1 + assert map1.gpu_block_size == [10, 1, 1] + assert map1.gpu_launch_bounds == "0" + + assert len(map2.params) == 2 + assert map2.gpu_block_size == [2, 2, 1] + assert map2.gpu_launch_bounds == "8" + + assert len(map3.params) == 3 + assert map3.gpu_block_size == [10, 11, 12] + assert map3.gpu_launch_bounds == "200" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index c1e0ddd2f6..aac58eb32c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -29,7 +29,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np The k blocking transformation can be applied to the SDFG, however no node can be taken out. This is because how it is constructed. However, applying - some simplistic transformations this can be done. + some simplistic transformations will enable the transformation. """ sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -136,6 +136,83 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3)) +def _get_sdfg_with_empty_memlet( + first_tasklet_independent: bool, + only_empty_memlets: bool, +) -> tuple[ + dace.SDFG, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.Tasklet +]: + """Generates an SDFG with an empty tasklet. + + The map contains two (serial) tasklets, connected through an access node. + The first tasklet has an empty memlet that connects it to the map entry. + Depending on `first_tasklet_independent` the tasklet is either independent + or not. The second tasklet has an additional in connector that accesses an array. + + If `only_empty_memlets` is given then the second memlet will only depend + on the input of the first tasklet. However, since it is connected to the + map exit, it will be classified as dependent. + + Returns: + The function returns the SDFG, the map entry and the first tasklet (that + is either dependent or independent), the access node between the tasklets + and the second tasklet that is always dependent. + """ + sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + sdfg.add_array("b", ("N", "M"), dace.float64, transient=False) + b = state.add_access("b") + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + + if not only_empty_memlets: + sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + a = state.add_access("a") + + # This is the first tasklet. + task1 = state.add_tasklet( + "task1", + inputs={}, + outputs={"__out0"}, + code="__out0 = 1.0" if first_tasklet_independent else "__out0 = j", + ) + + if only_empty_memlets: + task2 = state.add_tasklet( + "task2", inputs={"__in0"}, outputs={"__out0"}, code="__out0 = __in0 + 1.0" + ) + else: + task2 = state.add_tasklet( + "task2", inputs={"__in0", "__in1"}, outputs={"__out0"}, code="__out0 = __in0 + __in1" + ) + + # Now create the map + mentry, mexit = state.add_map("map", ndrange={"i": "0:N", "j": "0:M"}) + + if not only_empty_memlets: + state.add_edge(a, None, mentry, "IN_a", dace.Memlet("a[0:N, 0:M]")) + state.add_edge(mentry, "OUT_a", task2, "__in1", dace.Memlet("a[i, j]")) + + state.add_edge(task2, "__out0", mexit, "IN_b", dace.Memlet("b[i, j]")) + state.add_edge(mexit, "OUT_b", b, None, dace.Memlet("b[0:N, 0:M]")) + + state.add_edge(mentry, None, task1, None, dace.Memlet()) + state.add_edge(task1, "__out0", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, task2, "__in0", dace.Memlet("tmp[0]")) + + if not only_empty_memlets: + mentry.add_in_connector("IN_a") + mentry.add_out_connector("OUT_a") + mexit.add_in_connector("IN_b") + mexit.add_out_connector("OUT_b") + + sdfg.validate() + + return sdfg, mentry, task1, tmp, task2 + + def test_only_dependent(): """Just applying the transformation to the SDFG. @@ -152,11 +229,12 @@ def test_only_dependent(): ref = reff(a, b) # Apply the transformation - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 assert len(sdfg.states()) == 1 state = sdfg.states()[0] @@ -216,11 +294,12 @@ def test_intermediate_access_node(): assert np.allclose(ref, c) # Apply the transformation. - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 # Inspect if the SDFG was modified correctly. # We only inspect `tmp` which now has to be between the two maps. @@ -254,12 +333,12 @@ def test_chained_access() -> None: c[:] = 0 # Apply the transformation. - ret = sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) - assert ret == 1, f"Expected that the transformation was applied 1 time, but it was {ret}." + assert count == 1 # Now run the SDFG to see if it is still the same sdfg(a=a, b=b, c=c, M=M, N=N) @@ -305,3 +384,422 @@ def test_chained_access() -> None: assert isinstance(inner_tasklet, dace_nodes.Tasklet) assert inner_tasklet not in first_level_tasklets + + +def test_direct_map_exit_connection() -> dace.SDFG: + """Generates a SDFG with a mapped independent tasklet connected to the map exit. + + Because the tasklet is connected to the map exit it can not be independent. + """ + sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_array("a", (10,), dace.float64, transient=False) + sdfg.add_array("b", (10, 30), dace.float64, transient=False) + tsklt, me, mx = state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:10", j=f"0:30"), + inputs=dict(__in0=dace.Memlet("a[i]")), + outputs=dict(__out=dace.Memlet("b[i, j]")), + code="__out = __in0 + 1", + external_edges=True, + ) + + assert all(out_edge.dst is tsklt for out_edge in state.out_edges(me)) + assert all(in_edge.src is tsklt for in_edge in state.in_edges(mx)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + assert all(isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(me)) + assert all(isinstance(in_edge.src, dace_nodes.MapExit) for in_edge in state.in_edges(mx)) + + +def test_empty_memlet_1(): + sdfg, mentry, itask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=True, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[itask] is mentry + assert scope_dict[tmp] is mentry + assert scope_dict[task2] is not mentry + assert scope_dict[task2] is not None + assert all( + isinstance(in_edge.src, dace_nodes.MapEntry) and in_edge.src is not mentry + for in_edge in state.in_edges(task2) + ) + + +def test_empty_memlet_2(): + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # Find the inner map entry + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def test_empty_memlet_3(): + # This is the only interesting case with only empty memlet. + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=True, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # The top map only has a single output, which is the empty edge, that is holding + # the inner map entry in the scope. + assert all(out_edge.data.is_empty() for out_edge in state.out_edges(mentry)) + assert state.in_degree(mentry) == 0 + assert state.out_degree(mentry) == 1 + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def _make_loop_blocking_sdfg_with_inner_map( + add_independent_part: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: + """ + Generate the SDFGs with an inner map. + + The SDFG has an inner map that is classified as dependent. If + `add_independent_part` is `True` then the SDFG has a part that is independent. + Note that everything is read from a single connector. + + Return: + The function will return the SDFG, the state and the map entry for the outer + and inner map. + """ + sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + + me_out, mx_out = state.add_map("outer_map", ndrange={"__i0": "0:10"}) + me_in, mx_in = state.add_map("inner_map", ndrange={"__i1": "0:10"}) + A, B = (state.add_access(name) for name in "AB") + tskl = state.add_tasklet( + "computation", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = __in1 + __in2" + ) + + if add_independent_part: + sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) + tskli = state.add_tasklet( + "independent_comp", inputs={"__field"}, outputs={"__out"}, code="__out = __field[1, 1]" + ) + + # construct the inner map of the map. + state.add_edge(A, None, me_out, "IN_A", dace.Memlet("A[0:10, 0:10]")) + me_out.add_in_connector("IN_A") + state.add_edge(me_out, "OUT_A", me_in, "IN_A", dace.Memlet("A[__i0, 0:10]")) + me_out.add_out_connector("OUT_A") + me_in.add_in_connector("IN_A") + state.add_edge(me_in, "OUT_A", tskl, "__in1", dace.Memlet("A[__i0, __i1]")) + me_in.add_out_connector("OUT_A") + + state.add_edge(me_out, "OUT_A", me_in, "IN_A1", dace.Memlet("A[__i0, 0:10]")) + me_in.add_in_connector("IN_A1") + state.add_edge(me_in, "OUT_A1", tskl, "__in2", dace.Memlet("A[__i0, 9 - __i1]")) + me_in.add_out_connector("OUT_A1") + + state.add_edge(tskl, "__out", mx_in, "IN_B", dace.Memlet("B[__i0, __i1]")) + mx_in.add_in_connector("IN_B") + state.add_edge(mx_in, "OUT_B", mx_out, "IN_B", dace.Memlet("B[__i0, 0:10]")) + mx_in.add_out_connector("OUT_B") + mx_out.add_in_connector("IN_B") + state.add_edge(mx_out, "OUT_B", B, None, dace.Memlet("B[0:10, 0:10]")) + mx_out.add_out_connector("OUT_B") + + # If requested add a part that is independent, i.e. is before the inner loop + if add_independent_part: + state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) + state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, tmp2, None, dace.Memlet("tmp2[0]")) + state.add_edge(tmp2, None, mx_out, "IN_tmp", dace.Memlet("C[__i0]")) + mx_out.add_in_connector("IN_tmp") + state.add_edge(mx_out, "OUT_tmp", C, None, dace.Memlet("C[0:10]")) + mx_out.add_out_connector("OUT_tmp") + + sdfg.validate() + return sdfg, state, me_out, me_in + + +def test_loop_blocking_inner_map(): + """ + Tests with an inner map, without an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(False) + assert all(oedge.dst is inner_map for oedge in state.out_edges(outer_map)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + assert all( + oedge.dst is not inner_map and isinstance(oedge.dst, dace_nodes.MapEntry) + for oedge in state.out_edges(outer_map) + ) + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst is inner_map for oedge in state.out_edges(inner_blocking_map)) + + +def test_loop_blocking_inner_map_with_independent_part(): + """ + Tests with an inner map with an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(True) + + # Find the parts that are independent. + itskl: dace_nodes.Tasklet = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.Tasklet) + ) + assert itskl.label == "independent_comp" + i_access_node: dace_nodes.AccessNode = next(oedge.dst for oedge in state.out_edges(itskl)) + assert i_access_node.data == "tmp" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst in {inner_blocking_map, itskl} for oedge in state.out_edges(outer_map)) + assert state.scope_dict()[i_access_node] is outer_map + assert all(oedge.dst is inner_blocking_map for oedge in state.out_edges(i_access_node)) + + +def _make_mixed_memlet_sdfg( + tskl1_independent: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.Tasklet]: + """ + Generates the SDFGs for the mixed Memlet tests. + + The SDFG that is generated has the following structure: + - `tsklt2`, is always dependent, it has an incoming connection from the + map entry, and an incoming, but empty, connection with `tskl1`. + - `tskl1` is connected to the map entry, depending on `tskl1_independent` + it is independent or dependent, it has an empty connection to `tskl2`, + thus it is sequenced before. + - Both have connection to other nodes down stream, but they are dependent. + + Returns: + A tuple containing the following objects. + - The SDFG. + - The SDFG state. + - The outer map entry node. + - `tskl1`. + - `tskl2`. + """ + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names_array = ["A", "B", "C"] + names_scalar = ["tmp1", "tmp2"] + for aname in names_array: + sdfg.add_array( + aname, + shape=((10,) if aname == "A" else (10, 10)), + dtype=dace.float64, + transient=False, + ) + for sname in names_scalar: + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, B, C, tmp1, tmp2 = (state.add_access(name) for name in names_array + names_scalar) + + me, mx = state.add_map("outer_map", ndrange={"i": "0:10", "j": "0:10"}) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1" if tskl1_independent else "__out = __in1 + j", + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 10.0", + ) + tskl3 = state.add_tasklet( + "tskl3", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + me.add_in_connector("IN_A") + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[i]")) + me.add_out_connector("OUT_A") + state.add_edge(tskl1, "__out", tmp1, None, dace.Memlet("tmp1[0]")) + + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10, 0:10]")) + me.add_in_connector("IN_B") + state.add_edge(me, "OUT_B", tskl2, "__in1", dace.Memlet("B[i, j]")) + me.add_out_connector("OUT_B") + state.add_edge(tskl2, "__out", tmp2, None, dace.Memlet("tmp2[0]")) + + # Add the empty Memlet that sequences `tskl1` before `tskl2`. + state.add_edge(tskl1, None, tskl2, None, dace.Memlet()) + + state.add_edge(tmp1, None, tskl3, "__in1", dace.Memlet("tmp1[0]")) + state.add_edge(tmp2, None, tskl3, "__in2", dace.Memlet("tmp2[0]")) + state.add_edge(tskl3, "__out", mx, "IN_C", dace.Memlet("C[i, j]")) + mx.add_in_connector("IN_C") + state.add_edge(mx, "OUT_C", C, None, dace.Memlet("C[0:10, 0:10]")) + mx.add_out_connector("OUT_C") + sdfg.validate() + + return (sdfg, state, me, tskl1, tskl2) + + +def _apply_and_run_mixed_memlet_sdfg(sdfg: dace.SDFG) -> None: + ref = { + "A": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "B": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + "C": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + } + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1, f"Expected one application, but git {count}" + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref) + + +def test_loop_blocking_mixked_memlets_1(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(True) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Ensure that `tskl1` is independent. + assert scope_dict[tskl1] is me + + # The output of `tskl1`, which is `tmp1` should also be classified as independent. + tmp1 = next(iter(edge.dst for edge in state.out_edges(tskl1) if not edge.data.is_empty())) + assert scope_dict[tmp1] is me + assert isinstance(tmp1, dace_nodes.AccessNode) + assert tmp1.data == "tmp1" + + # Find the inner map. + inner_map_entry: dace_nodes.MapEntry = scope_dict[tskl2] + assert inner_map_entry is not me and isinstance(inner_map_entry, dace_nodes.MapEntry) + inner_map_exit: dace_nodes.MapExit = state.exit_node(inner_map_entry) + + outer_scope = {tskl1, tmp1, inner_map_entry, inner_map_exit, mx} + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert node in outer_scope + else: + assert ( + (node is inner_map_exit) + or (isinstance(node, dace_nodes.AccessNode) and node.data == "tmp2") + or (isinstance(node, dace_nodes.Tasklet) and node.label in {"tskl2", "tskl3"}) + ) + + +def test_loop_blocking_mixked_memlets_2(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(False) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Because `tskl1` is now dependent, everything is now dependent. + inner_map_entry = scope_dict[tskl1] + assert isinstance(inner_map_entry, dace_nodes.MapEntry) + assert inner_map_entry is not me + + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert isinstance(node, dace_nodes.MapEntry) or (node is mx) + else: + assert scope_dict[node] is inner_map_entry diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py new file mode 100644 index 0000000000..1a4ce6d047 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -0,0 +1,264 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: + return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} + + +def _make_test_sdfg( + output_name: str = "G", + input_name: str = "G", + tmp_name: str = "T", + array_size: int | str = 10, + tmp_size: int | str | None = None, + map_range: tuple[int | str, int | str] | None = None, + tmp_to_glob_memlet: str | None = None, + in_offset: str | None = None, + out_offset: str | None = None, +) -> dace.SDFG: + if isinstance(array_size, str): + array_size = sdfg.add_symbol(array_size, dace.int32, find_new_name=True) + if tmp_size is None: + tmp_size = array_size + if map_range is None: + map_range = (0, array_size) + if tmp_to_glob_memlet is None: + tmp_to_glob_memlet = f"{tmp_name}[0:{array_size}] -> [0:{array_size}]" + elif tmp_to_glob_memlet[0] == "[": + tmp_to_glob_memlet = tmp_name + tmp_to_glob_memlet + if in_offset is None: + in_offset = "0" + if out_offset is None: + out_offset = in_offset + + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + names = {input_name, tmp_name, output_name} + for name in names: + sdfg.add_array( + name, + shape=((array_size,) if name != tmp_name else (tmp_size,)), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays[tmp_name].transient = True + + input_ac = state.add_access(input_name) + tmp_ac = state.add_access(tmp_name) + output_ac = state.add_access(output_name) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": f"{map_range[0]}:{map_range[1]}"}, + inputs={"__in1": dace.Memlet(data=input_ac.data, subset=f"__i0 + {in_offset}")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet(data=tmp_ac.data, subset=f"__i0 + {out_offset}")}, + input_nodes={input_ac}, + output_nodes={tmp_ac}, + external_edges=True, + ) + state.add_edge( + tmp_ac, + None, + output_ac, + None, + dace.Memlet(tmp_to_glob_memlet), + ) + sdfg.validate() + return sdfg + + +def _perform_test( + sdfg: dace.SDFG, + xform: gtx_transformations.GT4PyMapBufferElimination, + exp_count: int, + array_size: int = 10, +) -> None: + ref = { + name: np.array(np.random.rand(array_size), dtype=np.float64, copy=True) + for name, desc in sdfg.arrays.items() + if not desc.transient + } + if "array_size" in sdfg.symbols: + ref["array_size"] = array_size + + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + assert count == exp_count, f"Expected {exp_count} applications, but got {count}" + + if count == 0: + return + + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()), f"Failed for '{name}'." + + +def test_map_buffer_elimination_simple(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=True), + exp_count=1, + ) + + +def test_map_buffer_elimination_simple_2(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_simple_3(): + sdfg = _make_test_sdfg(input_name="A", output_name="O") + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_1(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_to_glob_memlet="[2:8] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_2(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [0:6]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_3(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_4(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[1:7] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_offset_5(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_size=6, + in_offset="0", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_not_apply(): + """Indirect accessing, because of this the double buffer is needed.""" + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + + names = ["A", "tmp", "idx"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.int32 if name == "tmp" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + tmp = state.add_access("tmp") + state.add_mapped_tasklet( + "indirect_accessing", + map_ranges={"__i0": "0:10"}, + inputs={ + "__field": dace.Memlet("A[0:10]"), + "__idx": dace.Memlet("idx[__i0]"), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_nedge(tmp, state.add_access("A"), dace.Memlet("tmp[0:10] -> [0:10]")) + + # TODO(phimuell): Update the transformation such that we can specify + # `assume_pointwise=True` and the test would still pass. + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index c9d467ba80..b468b80b8e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -58,14 +58,14 @@ def _make_serial_sdfg_1( inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, code="__out = __in0 + 1.0", outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -118,17 +118,14 @@ def _make_serial_sdfg_2( "__out0": dace.Memlet("tmp_1[__i0, __i1]"), "__out1": dace.Memlet("tmp_2[__i0, __i1]"), }, - output_nodes={ - "tmp_1": tmp_1, - "tmp_2": tmp_2, - }, + output_nodes={tmp_1, tmp_2}, external_edges=True, ) state.add_mapped_tasklet( name="first_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_1": tmp_1}, + input_nodes={tmp_1}, inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -137,7 +134,7 @@ def _make_serial_sdfg_2( state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_2": tmp_2}, + input_nodes={tmp_2}, inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, code="__out = __in0 - 3.0", outputs={"__out": dace.Memlet("c[__i0, __i1]")}, @@ -194,14 +191,14 @@ def _make_serial_sdfg_3( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("tmp[__i0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="indirect_access", map_ranges=[("__i0", f"0:{N_output}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={ "__index": dace.Memlet("idx[__i0]"), "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), @@ -220,19 +217,19 @@ def test_exclusive_itermediate(): sdfg = _make_serial_sdfg_1(N) # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" not in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b"] ] assert len(intermediate_nodes) == 1 @@ -257,19 +254,19 @@ def test_shared_itermediate(): sdfg.arrays["tmp"].transient = False # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b", "tmp"] ] assert len(intermediate_nodes) == 1 @@ -291,21 +288,21 @@ def test_pure_output_node(): """Tests the path of a pure intermediate.""" N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 # The first fusion will only bring it down to two maps. sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -327,17 +324,17 @@ def test_array_intermediate(): """ N = 10 sdfg = _make_serial_sdfg_1(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 4 # Now perform the fusion sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + gtx_transformations.MapFusionSerial(only_toplevel_maps=True), validate=True, validate_all=True, ) - map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + map_entries = util.count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) scope = next(iter(sdfg.states())).scope_dict() assert len(map_entries) == 3 @@ -349,7 +346,7 @@ def test_array_intermediate(): # Find the access node that is the new intermediate node. inner_access_nodes: list[dace_nodes.AccessNode] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None ] assert len(inner_access_nodes) == 1 @@ -374,7 +371,7 @@ def test_interstate_transient(): """ N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 assert sdfg.number_of_nodes() == 1 # Now add the new state and the new output. @@ -393,15 +390,15 @@ def test_interstate_transient(): # Now apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) assert "tmp_1" in sdfg.arrays assert "tmp_2" not in sdfg.arrays assert sdfg.number_of_nodes() == 2 - assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 - assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(new_state, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -430,7 +427,7 @@ def test_indirect_access(): c = np.empty(N_output) idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 def _ref(a, b, idx): tmp = a + b @@ -443,11 +440,11 @@ def _ref(a, b, idx): # Now "apply" the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 c[:] = -1.0 sdfg(a=a, b=b, idx=idx, c=c) @@ -455,5 +452,58 @@ def _ref(a, b, idx): def test_indirect_access_2(): - # TODO(phimuell): Index should be computed and that map should be fusable. - pass + """Indirect accesses, with non point wise input dependencies. + + Because `a` is used as input and output and `a` is indirectly accessed + the access to `a` can not be point wise so, fusing is not possible. + """ + sdfg = dace.SDFG(util.unique_name("indirect_access_sdfg_2")) + state = sdfg.add_state(is_start_block=True) + + names = ["a", "b", "idx", "tmp"] + + for name in names: + sdfg.add_array( + name=name, + shape=(10,), + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + a_in, b, idx, tmp, a_out = (state.add_access(name) for name in (names + ["a"])) + + state.add_mapped_tasklet( + "indirect_access", + map_ranges={"__i0": "0:10"}, + inputs={ + "__idx": dace.Memlet("idx[__i0]"), + "__field": dace.Memlet("a[0:10]", volume=1), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + input_nodes={a_in, idx}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("tmp[__i0]"), + "__in2": dace.Memlet("b[__i0]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("a[__i0]")}, + input_nodes={tmp, b}, + output_nodes={a_out}, + external_edges=True, + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.MapFusionSerial(), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py new file mode 100644 index 0000000000..72efc2fe34 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -0,0 +1,100 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + + +def _perform_reorder_test( + sdfg: dace.SDFG, + leading_dim: list[str], + expected_order: list[str], +) -> None: + """Performs the reorder transformation and test it. + + If `expected_order` is the empty list, then the transformation should not apply. + """ + map_entries: list[dace.nodes.MapEntry] = util.count_nodes(sdfg, dace.nodes.MapEntry, True) + assert len(map_entries) == 1 + map_entry: dace.nodes.MapEntry = map_entries[0] + old_map_params = map_entry.map.params.copy() + + apply_count = sdfg.apply_transformations_repeated( + gtx_transformations.MapIterationOrder( + leading_dims=leading_dim, + ), + validate=True, + validate_all=True, + ) + new_map_params = map_entry.map.params.copy() + + if len(expected_order) == 0: + assert ( + apply_count == 0 + ), f"Expected that the transformation was not applied. New map order: {map_entry.map.params}" + return + else: + assert ( + apply_count > 0 + ), f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}" + assert len(expected_order) == len(new_map_params) + + assert ( + expected_order == new_map_params + ), f"Expected map order {expected_order} but got {new_map_params} instead." + + +def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: + """Generate an SDFG for the test.""" + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) + dim = len(map_params) + for aname in ["a", "b"]: + sdfg.add_array(aname, shape=((4,) * dim), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "mapped_tasklet", + map_ranges=[(map_param, "0:4") for map_param in map_params], + inputs={"__in": dace.Memlet("a[" + ",".join(map_params) + "]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("b[" + ",".join(map_params) + "]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_map_order_1(): + sdfg = _make_test_sdfg(["EDim", "KDim", "VDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim", "EDim"]) + + +def test_map_order_2(): + sdfg = _make_test_sdfg(["VDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim"]) + + +def test_map_order_3(): + sdfg = _make_test_sdfg(["EDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "EDim"]) + + +def test_map_order_4(): + sdfg = _make_test_sdfg(["CDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], []) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py new file mode 100644 index 0000000000..7b39bc4e1d --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -0,0 +1,164 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_movable_tasklet( + outer_tasklet_code: str, +) -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry +]: + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + + sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True) + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + A, B, outer_scalar = (state.add_access(name) for name in ["A", "B", "outer_scalar"]) + + outer_tasklet = state.add_tasklet( + name="outer_tasklet", + inputs=set(), + outputs={"__out"}, + code=f"__out = {outer_tasklet_code}", + ) + state.add_edge(outer_tasklet, "__out", outer_scalar, None, dace.Memlet("outer_scalar[0]")) + + _, me, _ = state.add_mapped_tasklet( + "map", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("outer_scalar[0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + input_nodes={outer_scalar, A}, + output_nodes={B}, + ) + sdfg.validate() + + return sdfg, state, outer_tasklet, outer_scalar, me + + +def test_move_tasklet_inside_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_non_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + # By expanding the maps, we the memlet tree is no longer trivial. + sdfg.apply_transformations_repeated(dace_dataflow.MapExpansion) + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + me = None + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_two_inner_connector(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="32.2", + ) + mapped_tasklet = next( + iter(e.dst for e in state.out_edges(me) if isinstance(e.dst, dace_nodes.Tasklet)) + ) + + state.add_edge( + me, + f"OUT_{outer_scalar.data}", + mapped_tasklet, + "__in2", + dace.Memlet(f"{outer_scalar.data}[0]"), + ) + mapped_tasklet.add_in_connector("__in2") + mapped_tasklet.code.as_string = "__out = __in0 + __in1 + __in2" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 2 * (32.2) + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_outer_scalar_used_outside(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="22.6", + ) + sdfg.add_array("C", shape=(1,), dtype=dace.float64, transient=False) + state.add_edge(outer_scalar, None, state.add_access("C"), None, dace.Memlet("C[0]")) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + C = np.array(np.random.rand(1), dtype=np.float64, copy=True) + ref_C = 22.6 + ref_B = A + ref_C + + csdfg = sdfg.compile() + csdfg(A=A, B=B, C=C) + assert np.allclose(B, ref_B) + assert np.allclose(C, ref_C) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index 96584b8273..8626cb8e07 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -68,7 +68,7 @@ def test_serial_map_promotion(): external_edges=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 1 assert len(map_entry_1d.map.range) == 1 assert len(map_entry_2d.map.params) == 2 @@ -83,7 +83,7 @@ def test_serial_map_promotion(): validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 assert len(map_entry_1d.map.range) == 2 assert len(map_entry_2d.map.params) == 2 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index ac88f4fef8..b82cecee98 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -14,7 +14,7 @@ @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[False], @@ -22,14 +22,14 @@ def _count_nodes( @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[True], ) -> list[dace_nodes.Node]: ... -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: bool = False, From 39fb949c2e7d0ff9ff4f1b9c3fff921ee8561086 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 5 Dec 2024 10:03:49 +0100 Subject: [PATCH 071/178] style[cartesian]: readability improvements and more type hints (#1752) This PR detaches a couple of cleanups in the dace backend from the in-progress gt4py/dace bridge: mostly readability improvements and some easy type hints. There's also the occasional unused variable / argument in here. --- src/gt4py/cartesian/backend/dace_backend.py | 16 +++--- .../gtc/dace/expansion_specification.py | 50 +++++++++---------- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 3 +- src/gt4py/cartesian/gtc/dace/utils.py | 2 +- .../test_code_generation.py | 2 +- 5 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index f49895a435..a6d28f5994 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -56,17 +56,17 @@ def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): - repldict = replace_strides( + replacement_dictionary = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map ) - sdfg.replace_dict(repldict) + sdfg.replace_dict(replacement_dictionary) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, dace.nodes.NestedSDFG): - for k, v in repldict.items(): + for k, v in replacement_dictionary.items(): if k in node.symbol_mapping: node.symbol_mapping[k] = v - for k in repldict.keys(): + for k in replacement_dictionary.keys(): if k in sdfg.symbols: sdfg.remove_symbol(k) @@ -143,7 +143,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None: node.device = dace.DeviceType.GPU -def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): +def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): args_data = make_args_data_from_gtir(gtir_pipeline) # stencils without effect @@ -164,7 +164,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map) return sdfg -def _post_expand_trafos(sdfg: dace.SDFG): +def _post_expand_transformations(sdfg: dace.SDFG): # DaCe "standard" clean-up transformations sdfg.simplify(validate=False) @@ -355,7 +355,7 @@ def _unexpanded_sdfg(self): sdfg = OirSDFGBuilder().visit(oir_node) _to_device(sdfg, self.builder.backend.storage_info["device"]) - _pre_expand_trafos( + _pre_expand_transformations( self.builder.gtir_pipeline, sdfg, self.builder.backend.storage_info["layout_map"], @@ -371,7 +371,7 @@ def unexpanded_sdfg(self): def _expanded_sdfg(self): sdfg = self._unexpanded_sdfg() sdfg.expand_library_nodes() - _post_expand_trafos(sdfg) + _post_expand_transformations(sdfg) return sdfg def expanded_sdfg(self): diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index c716f1a103..af9a814843 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis): for idx, item in enumerate(expansion_order): if isinstance(item, Iteration) and item.axis == axis: return idx - elif isinstance(item, Map): + + if isinstance(item, Map): for it in item.iterations: if it.kind == "contiguous" and it.axis == axis: return idx @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo): return eo -def _order_as_spec(computation_node, expansion_order): +def _order_as_spec( + computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] +) -> List[ExpansionItem]: expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order) expansion_specification = [] for item in expansion_order: @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order): return expansion_specification -def _populate_strides(node, expansion_specification): +def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]): """Fill in `stride` attribute of `Iteration` and `Loop` dataclasses. For loops, stride is set to either -1 or 1, based on iteration order. @@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification): for it in iterations: if isinstance(it, Loop): if it.stride is None: - if node.oir_node.loop_order == common.LoopOrder.BACKWARD: - it.stride = -1 - else: - it.stride = 1 + it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1 else: if it.stride is None: if it.kind == "tiling": @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification): it.stride = 1 -def _populate_storages(self, expansion_specification): +def _populate_storages(expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) innermost_axes = set(dcir.Axis.dims_3d()) tiled_axes = set() @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification): tiled_axes.remove(it.axis) -def _populate_cpu_schedules(self, expansion_specification): +def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]): is_outermost = True for es in expansion_specification: if isinstance(es, Map): @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_gpu_schedules(self, expansion_specification): +def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]): # On GPU if any dimension is tiled and has a contiguous map in the same axis further in # pick those two maps as Device/ThreadBlock maps. If not, Make just device map with # default blocksizes @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_schedules(self, expansion_specification): +def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - _populate_gpu_schedules(self, expansion_specification) + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + _populate_gpu_schedules(expansion_specification) else: - _populate_cpu_schedules(self, expansion_specification) + _populate_cpu_schedules(expansion_specification) -def _collapse_maps_gpu(self, expansion_specification): +def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: def _union_map_items(last_item, next_item): if last_item.schedule == next_item.schedule: return ( @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item): ), ) - res_items = [] + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if not res_items or not isinstance(res_items[-1], Map): @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item): return res_items -def _collapse_maps_cpu(self, expansion_specification): - res_items = [] +def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if ( @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, expansion_specification): return res_items -def _collapse_maps(self, expansion_specification): - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - res_items = _collapse_maps_gpu(self, expansion_specification) +def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]): + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + res_items = _collapse_maps_gpu(expansion_specification) else: - res_items = _collapse_maps_cpu(self, expansion_specification) + res_items = _collapse_maps_cpu(expansion_specification) expansion_specification.clear() expansion_specification.extend(res_items) @@ -387,7 +387,7 @@ def make_expansion_order( _populate_strides(node, expansion_specification) _populate_schedules(node, expansion_specification) _collapse_maps(node, expansion_specification) - _populate_storages(node, expansion_specification) + _populate_storages(expansion_specification) return expansion_specification diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index f12c13cd0e..14448bb08e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -123,6 +123,7 @@ def visit_VerticalLoop( state.add_edge( access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) ) + for field in access_collection.write_fields(): access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) library_node.add_out_connector("__out_" + field) @@ -131,8 +132,6 @@ def visit_VerticalLoop( library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) ) - return - def visit_Stencil(self, node: oir.Stencil, **kwargs): ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 517e80ceb3..bd65861a49 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -40,7 +40,7 @@ def array_dimensions(array: dace.data.Array): return dims -def replace_strides(arrays, get_layout_map): +def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]: symbol_mapping = {} for array in arrays: dims = array_dimensions(array) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index e51b3ef09d..4609184547 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -421,7 +421,7 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): - @gtscript.stencil(backend, externals={"mord": 5}) + @gtscript.stencil(backend) def stencil(outp: gtscript.Field[np.float_]): with computation(PARALLEL), interval(...): cond = True From 8b6abc22fe07da99157afc3a03d7c3911651bff8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 08:18:35 +0100 Subject: [PATCH 072/178] refactor[next]: remove use of Fencil in tracing (eliminate `closure`) (#1772) --- src/gt4py/next/iterator/embedded.py | 11 ++--- src/gt4py/next/iterator/runtime.py | 9 +--- src/gt4py/next/iterator/tracing.py | 44 +++---------------- .../program_processors/runners/roundtrip.py | 1 - .../iterator_tests/test_builtins.py | 4 +- tests/next_tests/unit_tests/conftest.py | 1 + .../iterator_tests/test_runtime_domain.py | 9 ++-- 7 files changed, 22 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3c63ffef30..13c64e264e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1706,8 +1706,10 @@ def impl(*iters: ItIterator): return impl -def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: - return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} +def _dimension_to_tag( + domain: runtime.CartesianDomain | runtime.UnstructuredDomain, +) -> dict[Tag, range]: + return {k.value: v for k, v in domain.items()} def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: @@ -1828,7 +1830,7 @@ def impl(*args): # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function - closure(_dimension_to_tag(domain), fun, out, list(args)) + closure(domain, fun, out, list(args)) return out return impl @@ -1839,9 +1841,8 @@ def index(axis: common.Dimension) -> common.Field: return IndexField(axis) -@runtime.closure.register(EMBEDDED) def closure( - domain_: Domain, + domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], out, #: MutableLocatedField, ins: list[common.Field | Scalar | tuple[common.Field | Scalar | tuple, ...]], diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index d42f961202..e47a6886ad 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"] +__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] @dataclass(frozen=True) @@ -163,7 +163,7 @@ def impl(out, *inps): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) - closure(dom, self.fundef_dispatcher, out, [*inps]) + set_at(builtins.as_fieldop(self.fundef_dispatcher, dom)(*inps), dom, out) return impl @@ -208,11 +208,6 @@ def fundef(fun): return FundefDispatcher(fun) -@builtin_dispatch -def closure(*args): # TODO remove - return BackendNotSelectedError() - - @builtin_dispatch def set_at(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 6772d4b507..81e9551e5c 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -23,7 +23,6 @@ Lambda, NoneLiteral, OffsetLiteral, - StencilClosure, Sym, SymRef, ) @@ -202,9 +201,6 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] - closures: ClassVar[ - List[StencilClosure] - ] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils body: ClassVar[List[itir.Stmt]] = [] @classmethod @@ -212,10 +208,6 @@ def add_fundef(cls, fun): if fun not in cls.fundefs: cls.fundefs.append(fun) - @classmethod - def add_closure(cls, closure): - cls.closures.append(closure) - @classmethod def add_stmt(cls, stmt): cls.body.append(stmt) @@ -225,23 +217,10 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): type(self).fundefs = [] - type(self).closures = [] type(self).body = [] iterator.builtins.builtin_dispatch.pop_key() -@iterator.runtime.closure.register(TRACING) -def closure(domain, stencil, output, inputs): - if hasattr(stencil, "__name__") and stencil.__name__ in iterator.builtins.__all__: - stencil = _s(stencil.__name__) - else: - stencil(*(_s(param) for param in inspect.signature(stencil).parameters)) - stencil = make_node(stencil) - TracerContext.add_closure( - StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - ) - - @iterator.runtime.set_at.register(TRACING) def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None: TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) @@ -328,19 +307,10 @@ def trace_fencil_definition( params = _make_fencil_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) - if TracerContext.closures: - return itir.FencilDefinition( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - closures=TracerContext.closures, - ) - else: - assert TracerContext.body - return itir.Program( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - declarations=[], # TODO - body=TracerContext.body, - ) + return itir.Program( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + declarations=[], # TODO + body=TracerContext.body, + ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 1dd568b95a..25eda5a2ed 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -46,7 +46,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") FunCall = as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") - StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") FunctionDefinition = as_mako( """ @fundef diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 5e3a2fcd14..c0a4cd166d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, + as_fieldop, bool, can_deref, cartesian_domain, @@ -45,9 +46,8 @@ plus, shift, xor_, - as_fieldop, ) -from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 99bc44efa7..8f6d5787d3 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -53,6 +53,7 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), + (next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, True), (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 13e8637d1a..bf2df06bf2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -27,21 +27,20 @@ def foo(inp): dtype=None, ) +I = gtx.Dimension("I") + def test_deduce_domain(): assert isinstance(_deduce_domain({}, {}), CartesianDomain) assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain) assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain) assert isinstance( - _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain + _deduce_domain(CartesianDomain([(I, range(1))]), {"foo": connectivity}), CartesianDomain ) -I = gtx.Dimension("I") - - def test_embedded_error_on_wrong_domain(): - dom = CartesianDomain([("I", range(1))]) + dom = CartesianDomain([(I, range(1))]) out = gtx.as_field([I], np.zeros(1)) with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): From 06813d54d9daec17bbac68aab32f6081c7f46b8e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 11:52:17 +0100 Subject: [PATCH 073/178] refactor[next]: remove all FencilDefinitions from tests (#1773) --- src/gt4py/next/iterator/ir.py | 10 +- .../iterator/transforms/symbol_ref_utils.py | 4 +- .../ffront_tests/test_decorator.py | 6 +- .../iterator_tests/test_pretty_parser.py | 36 +---- .../iterator_tests/test_pretty_printer.py | 36 +---- .../iterator_tests/test_type_inference.py | 151 +++++++----------- .../transforms_tests/test_symbol_ref_utils.py | 23 ++- .../gtfn_tests/test_gtfn_module.py | 41 ++--- 8 files changed, 96 insertions(+), 211 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 7098e9fa2e..6efee29362 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -189,17 +189,11 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu "scan", "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } -# only used in `Program`` not `FencilDefinition` -# TODO(havogt): restructure after refactoring to GTIR -GTIR_BUILTINS = { - *BUILTINS, - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) -} - class FencilDefinition(Node, ValidatedSymbolTableTrait): id: Coerced[SymbolName] @@ -243,7 +237,7 @@ class Program(Node, ValidatedSymbolTableTrait): implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(GTIR_BUILTINS) + Sym(id=name) for name in sorted(BUILTINS) ] # sorted for serialization stability diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 05163a3630..1765259a81 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -140,6 +140,4 @@ def collect_symbol_refs( def get_user_defined_symbols(symtable: dict[eve.SymbolName, itir.Sym]) -> set[str]: - return {str(sym) for sym in symtable.keys()} - { - str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_ - } + return {str(sym) for sym in symtable.keys()} - {str(n.id) for n in itir.Program._NODE_SYMBOLS_} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 47419c278b..45bf7428a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -30,10 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) - assert isinstance( - testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) - ) + assert isinstance(testee.itir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index da4bea8874..bf47f997d6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -208,18 +208,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = "y ← (deref)(x) @ cartesian_domain();" - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - actual = pparse(testee) - assert actual == expected - - def test_set_at(): testee = "y @ cartesian_domain() ← x;" expected = ir.SetAt( @@ -262,28 +250,6 @@ def test_if_stmt(): assert actual == expected -# TODO(havogt): remove after refactoring to GTIR -def test_fencil_definition(): - testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - expected = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pparse(testee) - assert actual == expected - - def test_program(): testee = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" expected = ir.Program( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 69a45cf128..11f50dbf6d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -313,18 +313,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - expected = "y ← (deref)(x) @ cartesian_domain();" - actual = pformat(testee) - assert actual == expected - - def test_set_at(): testee = ir.SetAt( expr=ir.SymRef(id="x"), @@ -336,28 +324,6 @@ def test_set_at(): assert actual == expected -# TODO(havogt): remove after refactoring. -def test_fencil_definition(): - testee = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pformat(testee) - expected = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - assert actual == expected - - def test_program(): testee = ir.Program( id="f", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 65a5b5888d..7eb4e86adb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -23,13 +23,12 @@ ) from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - from next_tests.integration_tests.cases import ( C2E, E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -37,11 +36,12 @@ Koff, V2EDim, Vertex, - Edge, - mesh_descriptor, exec_alloc_descriptor, + mesh_descriptor, unstructured_case, ) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh + bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) @@ -275,48 +275,35 @@ def test_cast_first_arg_inference(): assert result.type == float64_type -# TODO(tehrengruber): Rewrite tests to use itir.Program def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(im.ref("deref"), cartesian_domain))( + im.ref("inp") + ), domain=cartesian_domain, - stencil=im.ref("deref"), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[IDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_unstructured_fencil_definition(): @@ -326,44 +313,34 @@ def test_unstructured_fencil_definition(): im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_edge_k_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain + ) + )(im.ref("inp")), domain=unstructured_domain, - stencil=im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[Vertex, KDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[Vertex, KDim], - defined_dims=[Edge, KDim], - element_type=float64_type, - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_vertex_k_field, - inputs=[float_edge_k_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_edge_k_field, "out": float_vertex_k_field}, closures=[closure_type] + program_type = it_ts.ProgramType( + params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_vertex_k_field + assert result.body[0].target.type == float_vertex_k_field def test_function_definition(): @@ -371,45 +348,29 @@ def test_function_definition(): im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[ itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=cartesian_domain, - stencil=im.ref("bar"), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call(im.call("as_fieldop")(im.ref("bar"), cartesian_domain))(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_fencil_with_nb_field_input(): @@ -419,24 +380,30 @@ def test_fencil_with_nb_field_input(): im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=unstructured_domain, - stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), + unstructured_domain, + ) + )(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - assert result.closures[0].stencil.expr.args[0].type == float64_list_type - assert result.closures[0].stencil.type.returns == float64_type + stencil = result.body[0].expr.fun.args[0] + assert stencil.expr.args[0].type == float64_list_type + assert stencil.type.returns == float64_type def test_program_tuple_setat_short_target(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py index 0c118ff6dc..c162860c7c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py @@ -6,28 +6,23 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass -from typing import Optional -from gt4py import eve from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.symbol_ref_utils import ( - collect_symbol_refs, - get_user_defined_symbols, -) +from gt4py.next.iterator.transforms.symbol_ref_utils import get_user_defined_symbols def test_get_user_defined_symbols(): - ir = itir.FencilDefinition( + domain = itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]) + ir = itir.Program( id="foo", function_definitions=[], params=[itir.Sym(id="target_symbol")], - closures=[ - itir.StencilClosure( - domain=itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]), - stencil=itir.SymRef(id="deref"), - output=itir.SymRef(id="target_symbol"), - inputs=[], + declarations=[], + body=[ + itir.SetAt( + expr=itir.Lambda(params=[itir.Sym(id="foo")], expr=itir.SymRef(id="foo")), + domain=domain, + target=itir.SymRef(id="target_symbol"), ) ], ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e64bd8a57d..0586d48703 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -6,11 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -import pytest import copy -import diskcache +import diskcache +import numpy as np +import pytest import gt4py.next as gtx from gt4py.next.iterator import ir as itir @@ -19,18 +19,17 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation -from next_tests.integration_tests import cases -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim +from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case - from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + KDim, exec_alloc_descriptor, ) @pytest.fixture -def fencil_example(): +def program_example(): IDim = gtx.Dimension("I") params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] param_types = [type_translation.from_value(param) for param in params] @@ -48,7 +47,7 @@ def fencil_example(): ) ], ) - fencil = itir.FencilDefinition( + program = itir.Program( id="example", params=[im.sym(name, type_) for name, type_ in zip(("buf", "sc"), param_types)], function_definitions=[ @@ -58,20 +57,22 @@ def fencil_example(): expr=im.literal("1", "float32"), ) ], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(itir.SymRef(id="stencil"), domain))( + itir.SymRef(id="buf"), itir.SymRef(id="sc") + ), domain=domain, - stencil=itir.SymRef(id="stencil"), - output=itir.SymRef(id="buf"), - inputs=[itir.SymRef(id="buf"), itir.SymRef(id="sc")], + target=itir.SymRef(id="buf"), ) ], ) - return fencil, params + return program, params -def test_codegen(fencil_example): - fencil, parameters = fencil_example +def test_codegen(program_example): + fencil, parameters = program_example module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, @@ -85,8 +86,8 @@ def test_codegen(fencil_example): assert module.language is languages.CPP -def test_hash_and_diskcache(fencil_example, tmp_path): - fencil, parameters = fencil_example +def test_hash_and_diskcache(program_example, tmp_path): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( @@ -129,8 +130,8 @@ def test_hash_and_diskcache(fencil_example, tmp_path): ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) -def test_gtfn_file_cache(fencil_example): - fencil, parameters = fencil_example +def test_gtfn_file_cache(program_example): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( From 2c48858ff00f5f7ac2786f945bbf6bca60bfd4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:20:31 +0100 Subject: [PATCH 074/178] feat[dace]: Restirct Loop Blocking (#1775) Made it possible to disable loop blocking if there are no independent nodes. --- .../transformations/auto_optimize.py | 6 ++- .../transformations/loop_blocking.py | 24 ++++++++- .../test_loop_blocking.py | 49 +++++++++++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index bc1d21ca05..4a06d2f416 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -32,6 +32,7 @@ def gt_auto_optimize( gpu_block_size: Optional[Sequence[int | str] | str] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + blocking_only_if_independent_nodes: Optional[bool] = None, reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, @@ -90,6 +91,9 @@ def gt_auto_optimize( one for all. blocking_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. + blocking_only_if_independent_nodes: If `True` only apply loop blocking if + there are independent nodes in the Map, see the `require_independent_nodes` + option of the `LoopBlocking` transformation. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` @@ -101,7 +105,6 @@ def gt_auto_optimize( validate: Perform validation during the steps. validate_all: Perform extensive validation. - Note: For identifying symbols that can be treated as compile time constants `gt_find_constant_arguments()` function can be used. @@ -227,6 +230,7 @@ def gt_auto_optimize( gtx_transformations.LoopBlocking( blocking_size=blocking_size, blocking_parameter=blocking_dim, + require_independent_nodes=blocking_only_if_independent_nodes, ), validate=validate, validate_all=validate_all, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index d401c06f15..27b6c68072 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -36,12 +36,16 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): What makes this transformation different from simple blocking, is that the inner map will not just be inserted right after the outer Map. Instead the transformation will first identify all nodes that does not depend - on the blocking parameter `I` and relocate them between the outer and inner map. - Thus these operations will only be performed once, per inner loop. + on the blocking parameter `I`, called independent nodes and relocate them + between the outer and inner map. Note that an independent node must be connected + to the MapEntry or another independent node. + Thus these operations will only be performed once, per outer loop iteration. Args: blocking_size: The size of the block, denoted as `B` above. blocking_parameter: On which parameter should we block. + require_independent_nodes: If `True` only apply loop blocking if the Map + actually contains independent nodes. Defaults to `False`. Todo: - Modify the inner map such that it always starts at zero. @@ -59,6 +63,12 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) + require_independent_nodes = dace_properties.Property( + dtype=bool, + default=False, + desc="If 'True' then blocking is only applied if there are independent nodes.", + ) + # Set of nodes that are independent of the blocking parameter. _independent_nodes: Optional[set[dace_nodes.AccessNode]] _dependent_nodes: Optional[set[dace_nodes.AccessNode]] @@ -69,6 +79,7 @@ def __init__( self, blocking_size: Optional[int] = None, blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + require_independent_nodes: Optional[bool] = None, ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): @@ -77,6 +88,8 @@ def __init__( self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + if require_independent_nodes is not None: + self.require_independent_nodes = require_independent_nodes self._independent_nodes = None self._dependent_nodes = None @@ -250,6 +263,9 @@ def partition_map_output( member variables are updated. If the partition does not exists the function will return `False` and the respective member variables will be `None`. + The function will honor `self.require_independent_nodes`. Thus if no independent + nodes were found the function behaves as if the partition does not exist. + Args: state: The state on which we operate. sdfg: The SDFG in which we operate on. @@ -295,6 +311,10 @@ def partition_map_output( if not found_new_independent_node: break + if self.require_independent_nodes and len(self._independent_nodes) == 0: + self._independent_nodes = None + return False + # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. self._dependent_nodes = { diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index aac58eb32c..67bec9c09f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -803,3 +803,52 @@ def test_loop_blocking_mixked_memlets_2(): assert isinstance(node, dace_nodes.MapEntry) or (node is mx) else: assert scope_dict[node] is inner_map_entry + + +def test_loop_blocking_no_independent_nodes(): + import dace + + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names = ["A", "B"] + for aname in names: + sdfg.add_array( + aname, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "fully_dependent_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + + # Because there is nothing that is independent the transformation will + # not apply if `require_independent_nodes` is enabled. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=True, + ), + validate=True, + validate_all=True, + ) + assert count == 0 + + # But it will apply once this requirement is lifted. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 From 54f176f1e77536c4911d56ebaff35a53a7d37d6d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Dec 2024 15:21:19 +0100 Subject: [PATCH 075/178] refactor[next]: remove FencilDefinition definition (#1774) After this commit `FencilDefinition`s are completely removed. Next step could be to cleanup `itir` -> `gtir` everywhere. --- docs/user/next/advanced/HackTheToolchain.md | 2 +- src/gt4py/next/backend.py | 18 +- src/gt4py/next/ffront/decorator.py | 5 +- src/gt4py/next/ffront/foast_to_itir.py | 512 --------------- src/gt4py/next/ffront/past_to_itir.py | 115 +--- src/gt4py/next/iterator/ir.py | 37 +- src/gt4py/next/iterator/pretty_parser.py | 21 - src/gt4py/next/iterator/pretty_printer.py | 41 -- src/gt4py/next/iterator/tracing.py | 12 +- .../next/iterator/transforms/__init__.py | 4 +- .../iterator/transforms/collapse_tuple.py | 2 +- src/gt4py/next/iterator/transforms/cse.py | 4 +- .../iterator/transforms/fencil_to_program.py | 31 - .../next/iterator/transforms/pass_manager.py | 22 +- .../iterator/transforms/program_to_fencil.py | 31 - .../transforms/prune_closure_inputs.py | 44 -- .../iterator/transforms/symbol_ref_utils.py | 2 +- .../next/iterator/type_system/inference.py | 51 +- .../type_system/type_specifications.py | 24 - src/gt4py/next/otf/stages.py | 4 +- .../codegens/gtfn/gtfn_module.py | 14 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../program_processors/formatters/gtfn.py | 2 +- .../program_processors/formatters/lisp.py | 67 -- .../formatters/pretty_print.py | 2 +- .../program_processors/program_formatter.py | 8 +- .../runners/dace_fieldview/workflow.py | 2 +- .../next/program_processors/runners/gtfn.py | 2 +- .../program_processors/runners/roundtrip.py | 10 +- .../ffront_tests/test_decorator.py | 6 +- .../test_temporaries_with_sizes.py | 6 +- tests/next_tests/unit_tests/conftest.py | 1 - .../ffront_tests/test_foast_to_itir.py | 598 ------------------ .../ffront_tests/test_past_to_gtir.py | 11 +- .../ffront_tests/test_past_to_itir.py | 214 ------- .../transforms_tests/test_domain_inference.py | 21 +- .../test_prune_closure_inputs.py | 68 -- .../dace_tests/test_gtir_to_sdfg.py | 1 + 38 files changed, 91 insertions(+), 1926 deletions(-) delete mode 100644 src/gt4py/next/ffront/foast_to_itir.py delete mode 100644 src/gt4py/next/iterator/transforms/fencil_to_program.py delete mode 100644 src/gt4py/next/iterator/transforms/program_to_fencil.py delete mode 100644 src/gt4py/next/iterator/transforms/prune_closure_inputs.py delete mode 100644 src/gt4py/next/program_processors/formatters/lisp.py delete mode 100644 tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py delete mode 100644 tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 029833cb7d..358f6e8d0d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -15,7 +15,7 @@ from gt4py import eve ```python cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( - past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) + past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False) ) ``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e223d7771c..e075422ca3 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -16,7 +16,6 @@ from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_gtir, - foast_to_itir, foast_to_past, func_to_foast, func_to_past, @@ -41,7 +40,7 @@ ARGS: typing.TypeAlias = arguments.JITArgs CARG: typing.TypeAlias = arguments.CompileTimeArgs -IT_PRG: typing.TypeAlias = itir.FencilDefinition | itir.Program +IT_PRG: typing.TypeAlias = itir.Program INPUT_DATA: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG @@ -93,7 +92,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_itir_factory + default_factory=past_to_itir.past_to_gtir_factory ) def step_order(self, inp: INPUT_PAIR) -> list[str]: @@ -126,7 +125,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: ) case PRG(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) - case itir.FencilDefinition() | itir.Program(): + case itir.Program(): pass case _: raise ValueError("Unexpected input.") @@ -135,17 +134,6 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() -# FIXME[#1582](havogt): remove after refactoring to GTIR -# note: this step is deliberately placed here, such that the cache is shared -_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) -LEGACY_TRANSFORMS: Transforms = Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), - foast_to_itir=_foast_to_itir_step, - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=_foast_to_itir_step - ), -) - # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 61756f30c9..d187095019 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,7 +34,6 @@ from gt4py.next.ffront import ( field_operator_ast as foast, foast_to_gtir, - foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -186,7 +185,7 @@ def _all_closure_vars(self) -> dict[str, Any]: return transform_utils._get_closure_vars_recursively(self.past_stage.closure_vars) @functools.cached_property - def itir(self) -> itir.FencilDefinition: + def gtir(self) -> itir.Program: no_args_past = toolchain.CompilableProgram( data=ffront_stages.PastProgramDefinition( past_node=self.past_stage.past_node, @@ -561,7 +560,7 @@ def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: # a different backend than the one of the program that calls this field operator. Just use # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return foast_to_itir.foast_to_itir(self.foast_stage) + return foast_to_gtir.foast_to_gtir(self.foast_stage) # FIXME[#1582](tehrengruber): remove after refactoring to GTIR def __gt_gtir__(self) -> itir.FunctionDefinition: diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py deleted file mode 100644 index 538b0f3ddb..0000000000 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ /dev/null @@ -1,512 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# FIXME[#1582](havogt): remove after refactoring to GTIR - -import dataclasses -from typing import Any, Callable, Optional - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.ffront import ( - dialect_ast_enums, - fbuiltins, - field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, - type_specifications as ts_ffront, -) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.ffront.stages import AOT_FOP, FOP -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts - - -def foast_to_itir(inp: FOP) -> itir.Expr: - """ - Lower a FOAST field operator node to Iterator IR. - - See the docstring of `FieldOperatorLowering` for details. - """ - return FieldOperatorLowering.apply(inp.foast_node) - - -def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: - """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" - wf = foast_to_itir - if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) - return wf - - -def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: - """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" - return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) - - -def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node_type): - return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) - return lambda x: x - - -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). - - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. - - Examples - -------- - >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next import Field, Dimension, float64 - >>> - >>> IDim = Dimension("IDim") - >>> def fieldop(inp: Field[[IDim], "float64"]): - ... return inp - >>> - >>> parsed = FieldOperatorParser.apply_to_function(fieldop) - >>> lowered = FieldOperatorLowering.apply(parsed) - >>> type(lowered) - - >>> lowered.id - SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'))] - """ - - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs: Any - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) # `expr` is a lifted stencil - - def visit_FieldOperator( - self, node: foast.FieldOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - - new_body = func_definition.expr - - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body - ) - - def visit_ScanOperator( - self, node: foast.ScanOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - # note: we don't need the axis here as this is handled by the program - # decorator - assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # In iterator IR we didn't properly specify if this is legal, - # however after lift-inlining the expressions are transformed back to literals. - forward = im.deref(self.visit(node.forward, **kwargs)) - init = lowering_utils.process_elements( - im.deref, self.visit(node.init, **kwargs), node.init.type - ) - - # lower definition function - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.let( - func_definition.params[0].id, - # promote carry to iterator of tuples - # (this is the only place in the lowering were a variable is captured in a lifted lambda) - lowering_utils.to_tuples_of_iterator( - im.promote_to_const_iterator(func_definition.params[0].id), - [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] - ), - )( - # the function itself returns a tuple of iterators, deref element-wise - lowering_utils.process_elements( - im.deref, func_definition.expr, node.type.definition.returns - ) - ) - - stencil_args: list[itir.Expr] = [] - assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - for param, arg_type in zip( - func_definition.params[1:], - [*node.type.definition.pos_or_kw_args.values()][1:], - strict=True, - ): - if isinstance(arg_type, ts.TupleType): - # convert into iterator of tuples - stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - new_body = im.let( - param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - )(new_body) - else: - stencil_args.append(im.ref(param.id)) - - definition = itir.Lambda(params=func_definition.params, expr=new_body) - - body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) - - def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: - raise AssertionError("Statements must always be visited in the context of a function.") - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, **kwargs) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - return inner_expr - - def visit_IfStmt( - self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) - assert ( - isinstance(node.condition.type, ts.ScalarType) - and node.condition.type.kind == ts.ScalarKind.BOOL - ) - - cond = self.visit(node.condition, **kwargs) - - return_kind: StmtReturnKind = deduce_stmt_return_kind(node) - - common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - - if return_kind is StmtReturnKind.NO_RETURN: - # pack the common symbols into a tuple - common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - - # apply both branches and extract the common symbols through the prepared tuple - true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - - # unpack the common symbols' tuple for `inner_expr` - for i, sym in enumerate(common_symbols.keys()): - inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - - # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - - # wrap the inner expression in a lambda function. note that this increases the - # operation count if both branches are evaluated. - inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - inner_expr = im.call(inner_expr_name)(*common_symrefs) - - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.let(inner_expr_name, inner_expr_evaluator)( - im.if_(im.deref(cond), true_branch, false_branch) - ) - - assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN - - # note that we do not duplicate `inner_expr` here since if both branches - # return, `inner_expr` is ignored. - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.if_(im.deref(cond), true_branch, false_branch) - - def visit_Assign( - self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( - inner_expr - ) - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - - def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # TODO(tehrengruber): extend iterator ir to support unary operators - dtype = type_info.extract_dtype(node.type) - if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) - - def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - op = "if_" - args = (node.condition, node.true_expr, node.false_expr) - lowered_args: list[itir.Expr] = [ - lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - for arg in args - ] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [ - promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) - ] - op = im.call("map_")(op) - - return lowering_utils.to_tuples_of_iterator( - im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - ) - - def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - current_expr = self.visit(node.func, **kwargs) - - for arg in node.args: - match arg: - # `field(Off[idx])` - case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - current_expr = im.lift( - im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it"))) - )(current_expr) - # `field(Dim + idx)` - case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), - right=foast.Constant(value=offset_index), - ): - if arg.op == dialect_ast_enums.BinaryOperator.SUB: - offset_index *= -1 - current_expr = im.lift( - # TODO(SF-N): we rely on the naming-convention that the cartesian dimensions - # are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the - # offset provider. This is a rather unclean solution and should be - # improved. - im.lambda_("it")( - im.deref( - im.shift( - common.dimension_to_implicit_offset(dimension), offset_index - )("it") - ) - ) - )(current_expr) - # `field(Off)` - case foast.Name(id=offset_name): - # only a single unstructured shift is supported so returning here is fine even though we - # are in a loop. - assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # `field(as_offset(Off, offset_field))` - case foast.Call(func=foast.Name(id="as_offset")): - func_args = arg - # TODO(tehrengruber): Use type system to deduce the offset dimension instead of - # (e.g. to allow aliasing) - offset_dim = func_args.args[0] - assert isinstance(offset_dim, foast.Name) - offset_it = self.visit(func_args.args[1], **kwargs) - current_expr = im.lift( - im.lambda_("it", "offset")( - im.deref(im.shift(offset_dim.id, im.deref("offset"))("it")) - ) - )(current_expr, offset_it) - case _: - raise FieldOperatorLoweringError("Unexpected shift arguments!") - - return current_expr - - def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if type_info.type_class(node.func.type) is ts.FieldType: - return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - ): - visitor = getattr(self, f"_visit_{node.func.id}") - return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - return self._visit_type_constr(node, **kwargs) - elif isinstance( - node.func.type, - (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): - # ITIR has no support for keyword arguments. Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - self.visit(node.args, **kwargs), - self.visit(node.kwargs, **kwargs), - use_signature_ordering=True, - ) - result = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - result = lowering_utils.to_tuples_of_iterator( - result, node.func.type.definition.returns - ) - - return result - - raise AssertionError( - f"Call to object of type '{type(node.func.type).__name__}' not understood." - ) - - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, new_type = node.args[0], node.args[1].id - return lowering_utils.process_elements( - lambda x: im.promote_to_lifted_stencil( - im.lambda_("it")(im.call("cast_")("it", str(new_type))) - )(x), - self.visit(obj, **kwargs), - obj.type, - ) - - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - condition, true_value, false_value = node.args - - lowered_condition = self.visit(condition, **kwargs) - return lowering_utils.process_elements( - lambda tv, fv, types: _map( - "if_", (lowered_condition, tv, fv), (condition.type, *types) - ), - [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - node.type, - (node.args[1].type, node.args[2].type), - ) - - _visit_concat_where = _visit_where - - def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) - - def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) - - def _make_reduction_expr( - self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any - ) -> itir.Expr: - # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - it = self.visit(node.args[0], **kwargs) - assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) - - def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - - def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - min_value, _ = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(min_value), dtype) - return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - - def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - _, max_value = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(max_value), dtype) - return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - - def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - el = node.args[0] - node_kind = self.visit(node.type).kind.name.lower() - source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] - target_type = fbuiltins.BUILTINS[node_kind] - - if isinstance(el, foast.Constant): - val = source_type(el.value) - elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): - operand = source_type(el.operand.value) - val = eval(f"lambda arg: {el.op}arg")(operand) - else: - raise FieldOperatorLoweringError( - f"Type cast only supports literal arguments, {node.type} not supported." - ) - val = target_type(val) - - return im.promote_to_const_iterator(im.literal(str(val), node_kind)) - - def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): - typename = type_.kind.name.lower() - return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type '{type_}'.") - - def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - return self._make_literal(node.value, node.type) - - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - return _map( - op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) - ) - - -def _map( - op: itir.Expr | str, - lowered_args: tuple, - original_arg_types: tuple[ts.TypeSpec, ...], -) -> itir.FunCall: - """ - Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists. - """ - if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): - lowered_args = tuple( - promote_to_list(arg_type)(larg) - for arg_type, larg in zip(original_arg_types, lowered_args) - ) - op = im.call("map_")(op) - - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - - -class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c0348bb5c6..4ec12bb76b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import functools from typing import Any, Optional, cast import devtools @@ -19,7 +18,6 @@ from gt4py.next.ffront import ( fbuiltins, gtcallable, - lowering_utils, program_ast as past, stages as ffront_stages, transform_utils, @@ -32,10 +30,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgram: +def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: """ Lower a PAST program definition to Iterator IR. @@ -59,7 +56,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ... column_axis=None, ... ) - >>> itir_copy = past_to_itir( + >>> itir_copy = past_to_gtir( ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) ... ) @@ -67,7 +64,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra copy_program >>> print(type(itir_copy.data)) - + """ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -88,13 +85,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # making this step aware of the toolchain it is called by (it can be part of multiple). lowered_funcs = [] for gt_callable in gt_callables: - if to_gtir: - lowered_funcs.append(gt_callable.__gt_gtir__()) - else: - lowered_funcs.append(gt_callable.__gt_itir__()) + lowered_funcs.append(gt_callable.__gt_gtir__()) itir_program = ProgramLowering.apply( - inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or inp.data.debug: @@ -106,11 +100,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ) -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR -def past_to_itir_factory( - cached: bool = True, to_gtir: bool = True +def past_to_gtir_factory( + cached: bool = True, ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: - wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) return wf @@ -190,7 +183,7 @@ class ProgramLowering( ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP - + >>> lowered.id # doctest: +SKIP SymbolName('program') >>> lowered.params # doctest: +SKIP @@ -198,7 +191,6 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -209,11 +201,8 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR - ) -> itir.FencilDefinition: - return cls(grid_type=grid_type, to_gtir=to_gtir).visit( - node, function_definitions=function_definitions - ) + ) -> itir.Program: + return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -246,7 +235,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition | itir.Program: + ) -> itir.Program: # The ITIR does not support dynamically getting the size of a field. As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -259,27 +248,17 @@ def visit_Program( params = params + self._gen_size_params_from_program(node) implicit_domain = True - if self.to_gtir: - set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] - return itir.Program( - id=node.id, - function_definitions=function_definitions, - params=params, - declarations=[], - body=set_ats, - implicit_domain=implicit_domain, - ) - else: - closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] - return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, - implicit_domain=implicit_domain, - ) + set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=set_ats, + implicit_domain=implicit_domain, + ) - def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: + def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -303,56 +282,6 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. target=output, ) - # FIXME[#1582](havogt): remove after refactoring to GTIR - def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: - assert isinstance(node.kwargs["out"].type, ts.TypeSpec) - assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) - - node_kwargs = {**node.kwargs} - domain = node_kwargs.pop("domain", None) - output, lowered_domain = self._visit_stencil_call_out_arg( - node_kwargs.pop("out"), domain, **kwargs - ) - - assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - - args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, node.args, node_kwargs, use_signature_ordering=True - ) - - lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) - - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, - ) - - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 6efee29362..e875709631 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -9,7 +9,7 @@ from typing import ClassVar, List, Optional, Union import gt4py.eve as eve -from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve import Coerced, SymbolName, SymbolRef from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @@ -19,10 +19,6 @@ DimensionKind = common.DimensionKind -# TODO(havogt): -# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. -# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. - @noninstantiable class Node(eve.Node): @@ -97,23 +93,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -class StencilClosure(Node): - domain: FunCall - stencil: Expr - output: Union[SymRef, FunCall] - inputs: List[Union[SymRef, FunCall]] - - @datamodels.validator("output") - def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to 'make_tuple' allowed.") - - @datamodels.validator("inputs") - def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): - raise ValueError("Only FunCall to 'index' allowed.") - - UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -195,18 +174,6 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu } -class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - closures: List[StencilClosure] - implicit_domain: bool = False - - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(BUILTINS) - ] # sorted for serialization stability - - class Stmt(Node): ... @@ -252,8 +219,6 @@ class Program(Node, ValidatedSymbolTableTrait): Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] -StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] -FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index b4a673772f..29b30beae1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -216,10 +216,6 @@ def function_definition(self, *args: ir.Node) -> ir.FunctionDefinition: fid, *params, expr = args return ir.FunctionDefinition(id=fid, params=params, expr=expr) - def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: - output, stencil, *inputs, domain = args - return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - def if_stmt(self, cond: ir.Expr, *args): found_else_seperator = False true_branch = [] @@ -249,23 +245,6 @@ def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) - # TODO(havogt): remove after refactoring. - def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: - params = [] - function_definitions = [] - closures = [] - for arg in args: - if isinstance(arg, ir.Sym): - params.append(arg) - elif isinstance(arg, ir.FunctionDefinition): - function_definitions.append(arg) - else: - assert isinstance(arg, ir.StencilClosure) - closures.append(arg) - return ir.FencilDefinition( - id=fid, function_definitions=function_definitions, params=params, closures=closures - ) - def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 99287f8a11..a25f99356c 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -248,28 +248,6 @@ def visit_FunctionDefinition(self, node: ir.FunctionDefinition, prec: int) -> li vbody = self._vmerge(params, self._indent(expr)) return self._optimum(hbody, vbody) - def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[str]: - assert prec == 0 - domain = self.visit(node.domain, prec=0) - stencil = self.visit(node.stencil, prec=0) - output = self.visit(node.output, prec=0) - inputs = self.visit(node.inputs, prec=0) - - hinputs = self._hmerge(["("], *self._hinterleave(inputs, ", "), [")"]) - vinputs = self._vmerge(["("], *self._hinterleave(inputs, ",", indent=True), [")"]) - inputs = self._optimum(hinputs, vinputs) - - head = self._hmerge(output, [" ← "]) - foot = self._hmerge(inputs, [" @ "], domain, [";"]) - - h = self._hmerge(head, ["("], stencil, [")"], foot) - v = self._vmerge( - self._hmerge(head, ["("]), - self._indent(self._indent(stencil)), - self._indent(self._hmerge([")"], foot)), - ) - return self._optimum(h, v) - def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] @@ -312,25 +290,6 @@ def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]: head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] ) - def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: - assert prec == 0 - function_definitions = self.visit(node.function_definitions, prec=0) - closures = self.visit(node.closures, prec=0) - params = self.visit(node.params, prec=0) - - hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) - vparams = self._vmerge( - [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] - ) - params = self._optimum(hparams, vparams) - - function_definitions = self._vmerge(*function_definitions) - closures = self._vmerge(*closures) - - return self._vmerge( - params, self._indent(function_definitions), self._indent(closures), ["}"] - ) - def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 81e9551e5c..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -258,7 +258,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args) -> list[Sym]: +def _make_program_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -293,18 +293,16 @@ def _make_fencil_params(fun, args) -> list[Sym]: return params -def trace_fencil_definition( - fun: typing.Callable, args: typing.Iterable -) -> itir.FencilDefinition | itir.Program: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ - Transform fencil given as a callable into `itir.FencilDefinition` using tracing. + Transform fencil given as a callable into `itir.Program` using tracing. Arguments: - fun: The fencil / callable to trace. + fun: The program / callable to trace. args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. """ with TracerContext() as _: - params = _make_fencil_params(fun, args) + params = _make_program_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) return itir.Program( diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index aeccb5f26d..d0afc610e7 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.transforms.pass_manager import ( - ITIRTransform, + GTIRTransform, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e71a24127f..b64886f729 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -128,7 +128,7 @@ def apply( flags = flags or cls.flags offset_provider_type = offset_provider_type or {} - if isinstance(node, (ir.Program, ir.FencilDefinition)): + if isinstance(node, ir.Program): within_stencil = False assert within_stencil in [ True, diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 824adfdd8d..4f3fcbfdd5 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -376,7 +376,7 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children -ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) @dataclasses.dataclass(frozen=True) @@ -413,7 +413,7 @@ def apply( within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: - is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + is_program = isinstance(node, itir.Program) if is_program: assert within_stencil is None within_stencil = False diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py deleted file mode 100644 index 4ad91645d4..0000000000 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im - - -class FencilToProgram(eve.NodeTranslator): - @classmethod - def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: - return cls().visit(node) - - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) - - def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: - return itir.Program( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - declarations=[], - body=self.visit(node.closures), - implicit_domain=node.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec6f89685a..ec4207d726 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,13 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional, Protocol +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - fencil_to_program, fuse_as_fieldop, global_tmps, infer_domain, @@ -32,16 +31,16 @@ from gt4py.next.iterator.type_system.inference import infer -class ITIRTransform(Protocol): +class GTIRTransform(Protocol): def __call__( - self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + self, _: itir.Program, *, offset_provider: common.OffsetProvider ) -> itir.Program: ... # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, *, offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, @@ -49,10 +48,6 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, @@ -62,9 +57,6 @@ def apply_common_transforms( if offset_provider_type is None: offset_provider_type = common.offset_provider_to_type(offset_provider) - # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this - if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram.apply(ir) assert isinstance(ir, itir.Program) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") @@ -73,7 +65,7 @@ def apply_common_transforms( ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) # note: this increases the size of the tree @@ -82,7 +74,7 @@ def apply_common_transforms( # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( - ir, # type: ignore[arg-type] # always an itir.Program + ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) @@ -119,7 +111,7 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/program_to_fencil.py b/src/gt4py/next/iterator/transforms/program_to_fencil.py deleted file mode 100644 index 4411dda74f..0000000000 --- a/src/gt4py/next/iterator/transforms/program_to_fencil.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -def program_to_fencil(node: itir.Program) -> itir.FencilDefinition: - assert not node.declarations - closures = [] - for stmt in node.body: - assert isinstance(stmt, itir.SetAt) - assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop") - stencil, domain = stmt.expr.fun.args - inputs = stmt.expr.args - assert all(isinstance(inp, itir.SymRef) for inp in inputs) - closures.append( - itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs) - ) - - return itir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - closures=closures, - ) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py deleted file mode 100644 index 5058a91216..0000000000 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir - - -class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): - """Removes all unused input arguments from a stencil closure.""" - - def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: - if not isinstance(node.stencil, ir.Lambda): - return node - - unused: set[str] = {p.id for p in node.stencil.params} - expr = self.visit(node.stencil.expr, unused=unused, shadowed=set[str]()) - params = [] - inputs = [] - for param, inp in zip(node.stencil.params, node.inputs): - if param.id not in unused: - params.append(param) - inputs.append(inp) - - return ir.StencilClosure( - domain=node.domain, - stencil=ir.Lambda(params=params, expr=expr), - output=node.output, - inputs=inputs, - ) - - def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: - if node.id not in shadowed: - unused.discard(node.id) - return node - - def visit_Lambda(self, node: ir.Lambda, *, unused: set[str], shadowed: set[str]) -> ir.Lambda: - return self.generic_visit( - node, unused=unused, shadowed=shadowed | {p.id for p in node.params} - ) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1765259a81..2903201083 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -69,7 +69,7 @@ def apply( Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: - inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} + inactive_refs = {str(n.id) for n in itir.Program._NODE_SYMBOLS_} else: inactive_refs = set() diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index ffca6cc7a7..1b980783fa 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -352,7 +352,7 @@ def apply( Preconditions: - All parameters in :class:`itir.Program` and :class:`itir.FencilDefinition` must have a type + All parameters in :class:`itir.Program` must have a type defined, as they are the starting point for type propagation. Design decisions: @@ -401,9 +401,9 @@ def apply( # parts of a program. node = SanitizeTypes().visit(node) - if isinstance(node, (itir.FencilDefinition, itir.Program)): + if isinstance(node, itir.Program): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( - "All parameters in 'itir.Program' and 'itir.FencilDefinition' must have a type " + "All parameters in 'itir.Program' must have a type " "defined, as they are the starting point for type propagation.", ) @@ -460,20 +460,6 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - - function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} - for fun_def in node.function_definitions: - function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) - - closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=params, closures=closures) - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -532,37 +518,6 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: and target_type.dtype == expr_type.dtype ) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: - domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) - inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) - output: ts.FieldType = self.visit(node.output, ctx=ctx) - - assert isinstance(domain, it_ts.DomainType) - for output_el in type_info.primitive_constituents(output): - assert isinstance(output_el, ts.FieldType) - - stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) - stencil_args = [ - type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) - for input_ in inputs - ] - stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider_type=self.offset_provider_type - ) - - return it_ts.StencilClosureType( - domain=domain, - stencil=ts.FunctionType( - pos_only_args=stencil_args, - pos_or_kw_args={}, - kw_only_args={}, - returns=stencil_returns, - ), - output=output, - inputs=inputs, - ) - def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: assert ( node.value in self.dimensions diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index edb56f5659..eef8c75d0f 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -43,30 +43,6 @@ class IteratorType(ts.DataType, ts.CallableType): element_type: ts.DataType -@dataclasses.dataclass(frozen=True) -class StencilClosureType(ts.TypeSpec): - domain: DomainType - stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType - inputs: list[ts.FieldType] - - def __post_init__(self): - # local import to avoid importing type_info from a type_specification module - from gt4py.next.type_system import type_info - - for i, el_type in enumerate(type_info.primitive_constituents(self.output)): - assert isinstance( - el_type, ts.FieldType - ), f"All constituent types must be field types, but the {i}-th element is of type '{el_type}'." - - -# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) -class FencilType(ts.TypeSpec): - params: dict[str, ts.DataType] - closures: list[StencilClosureType] - - @dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 85838d9c76..22326c7e87 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -26,9 +26,7 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[ - itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs -] +CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index f1649112a7..020b1f55ea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Callable, Final, Optional +from typing import Any, Final, Optional import factory import numpy as np @@ -53,9 +53,6 @@ class GTFNTranslationStep( use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -80,7 +77,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, arg_types: tuple[ts.TypeSpec, ...], offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: @@ -157,7 +154,7 @@ def _process_connectivity_args( def _preprocess_program( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( @@ -167,7 +164,6 @@ def _preprocess_program( # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( @@ -186,7 +182,7 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: @@ -214,7 +210,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dc0012b041..d5b34fd5b9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -108,7 +108,7 @@ def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index db1242e2a4..5f32eaa2bb 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. gtfn_translation = gtfn.GTFNBackendFactory().executor.translation assert isinstance(gtfn_translation, GTFNTranslationStep) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py deleted file mode 100644 index 0a8253595e..0000000000 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ /dev/null @@ -1,67 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any - -from gt4py.eve.codegen import FormatTemplate as as_fmt, TemplatedGenerator -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import apply_common_transforms -from gt4py.next.program_processors import program_formatter - - -class ToLispLike(TemplatedGenerator): - Sym = as_fmt("{id}") - FunCall = as_fmt("({fun} {' '.join(args)})") - Literal = as_fmt("{value}") - OffsetLiteral = as_fmt("{value}") - SymRef = as_fmt("{id}") - StencilClosure = as_fmt( - """( - :domain {domain} - :stencil {stencil} - :output {output} - :inputs {' '.join(inputs)} - ) - """ - ) - FencilDefinition = as_fmt( - """ - ({' '.join(function_definitions)}) - (defen {id}({' '.join(params)}) - {''.join(closures)}) - """ - ) - FunctionDefinition = as_fmt( - """(defun {id}({' '.join(params)}) - {expr} - ) - -""" - ) - Lambda = as_fmt( - """(lambda ({' '.join(params)}) - {expr} - )""" - ) - - @classmethod - def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) - generated_code = super().apply(transformed, **kwargs) - try: - from yasi import indent_code - - indented = indent_code(generated_code, "--dialect lisp") - return "".join(indented["indented_code"]) - except ImportError: - return generated_code - - -@program_formatter.program_formatter -def format_lisp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return ToLispLike.apply(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index f14ac5653f..cbf9fd1978 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_itir_and_check(program: itir.Program, *args: Any, **kwargs: Any) -> str: pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) assert parsed == program diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index f77e7f32ee..321c09668c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -10,7 +10,7 @@ Interface for program processors. Program processors are functions which operate on a program paired with the input -arguments for the program. Programs are represented by an ``iterator.ir.itir.FencilDefinition`` +arguments for the program. Programs are represented by an ``iterator.ir.Program`` node. Program processors that execute the program with the given arguments (possibly by generating code along the way) are program executors. Those that generate any kind of string based on the program and (optionally) input values are program formatters. @@ -30,14 +30,14 @@ class ProgramFormatter: @abc.abstractmethod - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: ... + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: ... @dataclasses.dataclass(frozen=True) class WrappedProgramFormatter(ProgramFormatter): formatter: Callable[..., str] - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: return self.formatter(program, *args, **kwargs) @@ -47,7 +47,7 @@ def program_formatter(func: Callable[..., str]) -> ProgramFormatter: Examples: >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... def format_foo(fencil: itir.Program, *args, **kwargs) -> str: ... '''A very useless fencil formatter.''' ... return "foo" diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..a38a50d886 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -72,7 +72,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 55f479c665..c0a9be9168 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -125,7 +125,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: Generates a unique hash string for a stencil source program representing the program, sorted offset_provider, and column_axis. """ - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 25eda5a2ed..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -90,11 +90,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, - transforms: itir_transforms.ITIRTransform, + transforms: itir_transforms.GTIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -197,7 +197,7 @@ class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgr debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -265,10 +265,10 @@ def decorated_fencil( gtir = next_backend.Backend( name="roundtrip_gtir", - executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # don't understand why mypy complains allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + past_to_itir=past_to_itir.past_to_gtir_factory(), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), field_view_op_to_prog=foast_to_past.operator_to_program_factory( foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 45bf7428a6..9e80dba53b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -21,7 +21,7 @@ ) -def test_program_itir_regression(cartesian_case): +def test_program_gtir_regression(cartesian_case): @gtx.field_operator(backend=None) def testee_op(a: cases.IField) -> cases.IField: return a @@ -30,8 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.Program) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, itir.Program) + assert isinstance(testee.gtir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 66c56c4827..7d2eec772c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -107,12 +107,12 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - itir_with_tmp = apply_common_transforms( - testee.itir, + gtir_with_tmp = apply_common_transforms( + testee.gtir, extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.params]) + assert any([param == str(p) for p in gtir_with_tmp.params]) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8f6d5787d3..03662f8dcc 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -58,7 +58,6 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ], diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py deleted file mode 100644 index c102df9d57..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ /dev/null @@ -1,598 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation -from gt4py.next.iterator.type_system import type_specifications as it_ts - - -IDim = gtx.Dimension("IDim") -Edge = gtx.Dimension("Edge") -Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") -V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. - - -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - -def test_copy(): - def copy_field(inp: gtx.Field[[TDim], float64]): - return inp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - assert lowered.id == "copy_field" - assert lowered.expr == im.ref("inp") - - -def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: - return alpha * bar - - parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")( - "alpha", "bar" - ) # no difference to non-scalar arg - - assert lowered.expr == reference - - -def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1, inp2 - - parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple("inp1", "inp2") - - assert lowered.expr == reference - - -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 - - parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") - - assert lowered.expr == reference - - -def test_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_negative_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp = inp - inp = tmp - tmp2 = inp - return tmp2 - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")( - im.let( - ssa.unique_name("inp", 0), - ssa.unique_name("tmp", 0), - )( - im.let( - ssa.unique_name("tmp2", 0), - ssa.unique_name("inp", 0), - )(ssa.unique_name("tmp2", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_unary_ops(): - def unary(inp: gtx.Field[[TDim], float64]): - tmp = +inp - tmp = -tmp - return tmp - - parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("0", "float64")), "inp" - ), - )( - im.let( - ssa.unique_name("tmp", 1), - im.promote_to_lifted_stencil("minus")( - im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) - ), - )(ssa.unique_name("tmp", 1)) - ) - - assert lowered.expr == reference - - -@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) -def test_unary_op_type_conversion(var, var_type): - def unary_float(): - return float(-1) - - def unary_bool(): - return bool(-1) - - fun = unary_bool if var_type == "bool" else unary_float - parsed = FieldOperatorParser.apply_to_function(fun) - lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_const_iterator(im.literal(var, var_type)) - - assert lowered.expr == reference - - -def test_unpacking(): - """Unpacking assigns should get separated.""" - - def unpacking( - inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] - ) -> gtx.Field[[TDim], float64]: - tmp1, tmp2 = inp1, inp2 # noqa - return tmp1 - - parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("inp1", "inp2") - tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") - tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") - - reference = im.let("__tuple_tmp_0", tuple_expr)( - im.let( - ssa.unique_name("tmp1", 0), - tuple_access_0, - )( - im.let( - ssa.unique_name("tmp2", 0), - tuple_access_1, - )(ssa.unique_name("tmp1", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_annotated_assignment(): - pytest.xfail("Annotated assignments are not properly supported at the moment.") - - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp: gtx.Field[[TDim], float64] = inp - return tmp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_call(): - # create something that appears to the lowering like a field operator. - # we could also create an actual field operator, but we want to avoid - # using such heavy constructs for testing the lowering. - field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) - identity = SimpleNamespace( - __gt_type__=lambda: ts_ffront.FieldOperatorType( - definition=ts.FunctionType( - pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type - ) - ) - ) - - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: - return identity(inp) - - parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.call("identity")("inp") - - assert lowered.expr == reference - - -def test_temp_tuple(): - """Returning a temp tuple should work.""" - - def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("a", "b") - reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): - return not cond - - parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("not_")("cond") - - assert lowered.expr == reference - - -def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a + b - - parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("a", "b") - - assert lowered.expr == reference - - -def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: - return 2.0 + a - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" - ) - - assert lowered.expr == reference - - -def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: - tmp = int32(1) + int32("1") - return a + tmp - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - ), - )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) - - assert lowered.expr == reference - - -def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a * b - - parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")("a", "b") - - assert lowered.expr == reference - - -def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a - b - - parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("minus")("a", "b") - - assert lowered.expr == reference - - -def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a / b - - parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("divides")("a", "b") - - assert lowered.expr == reference - - -def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a & b - - parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")("a", "b") - - assert lowered.expr == reference - - -def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: - return a & False - - parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - "a", im.promote_to_const_iterator(im.literal("False", "bool")) - ) - - assert lowered.expr == reference - - -def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a | b - - parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("or_")("a", "b") - - assert lowered.expr == reference - - -def test_compare_scalars(): - def comp_scalars() -> bool: - return 3 > 4 - - parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")( - im.promote_to_const_iterator(im.literal("3", "int32")), - im.promote_to_const_iterator(im.literal("4", "int32")), - ) - - assert lowered.expr == reference - - -def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a > b - - parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")("a", "b") - - assert lowered.expr == reference - - -def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a < b - - parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("less")("a", "b") - - assert lowered.expr == reference - - -def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): - return a == b - - parsed = FieldOperatorParser.apply_to_function(comp_eq) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("eq")("a", "b") - - assert lowered.expr == reference - - -def test_compare_chain(): - def compare_chain( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: - return a > b > c - - parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - im.promote_to_lifted_stencil("greater")("a", "b"), - im.promote_to_lifted_stencil("greater")("b", "c"), - ) - - assert lowered.expr == reference - - -def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ) - ) - )(im.lifted_neighbors("V2E", "edge_f")) - - assert lowered.expr == reference - - -def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): - e1_nbh = e1(V2E) - return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( - im.promote_to_lifted_stencil("make_const_list")( - im.promote_to_const_iterator(im.literal("1.1", "float64")) - ), - im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), - ) - - reference = im.let( - ssa.unique_name("e1_nbh", 0), - im.lifted_neighbors("V2E", "e1"), - )( - im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref( - im.promote_to_const_iterator(im.literal(value="0", typename="float64")) - ), - ) - ) - )(mapped) - ) - - assert lowered.expr == reference - - -def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: - return 1, int32(1), int64(1), int32("1"), int64("1") - - parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - ) - - assert lowered.expr == reference - - -def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: - return ( - 0.1, - float(0.1), - float32(0.1), - float64(0.1), - float(".1"), - float32(".1"), - float64(".1"), - ) - - parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - ) - - assert lowered.expr == reference - - -def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: - return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - - parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), - ) - - assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index a6231c22a7..c813285bd0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -46,7 +46,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -93,7 +92,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -149,9 +147,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2[1:])) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail( @@ -166,9 +162,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2)) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail @@ -194,7 +188,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) assert exc_info.match("Invalid call to 'identity'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py deleted file mode 100644 index fefd3c653b..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ /dev/null @@ -1,214 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import re - -import pytest - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import errors -from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir -from gt4py.next.type_system import type_specifications as ts - -from next_tests.past_common_fixtures import ( - IDim, - copy_program_def, - copy_restrict_program_def, - float64, - identity_def, - invalid_call_sig_program_def, -) - - -@pytest.fixture -def itir_identity_fundef(): - return itir.FunctionDefinition( - id="identity", - params=[itir.Sym(id="x")], - expr=itir.FunCall(fun=itir.SymRef(id="deref"), args=[itir.SymRef(id="x")]), - ) - - -def test_copy_lowering(copy_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), - ], - ) - ], - ), - stencil=P( - itir.Lambda, - params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], - expr=P( - itir.FunCall, - fun=P( - itir.Lambda, - params=[P(itir.Sym)], - expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), - ), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("identity")), - args=[P(itir.SymRef, id=eve.SymbolRef("__stencil_arg0"))], - ) - ], - ), - ), - inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], - output=P(itir.SymRef, id=eve.SymbolRef("out")), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_restrict_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - ], - ) - ], - ), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_restrict_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2[1:])) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail( - reason="slicing is only allowed if all fields are sliced in the same way." -) # see ADR 10 -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2)) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail -def test_inout_prohibited(identity_def): - identity = gtx.field_operator(identity_def) - - def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): - identity(inout_field, out=inout_field) - - with pytest.raises( - ValueError, match=(r"Call to function with field as input and output not allowed.") - ): - ProgramLowering.apply( - ProgramParser.apply_to_function(inout_field_program), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - -def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises(errors.DSLError) as exc_info: - ProgramLowering.apply( - ProgramParser.apply_to_function(invalid_call_sig_program_def), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - assert exc_info.match("Invalid call to 'identity'") - # TODO(tehrengruber): re-enable again when call signature check doesn't return - # immediately after missing `out` argument - # assert ( - # re.search( - # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] - # ) - # is not None - # ) - assert ( - re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) - is not None - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 817c06e8f0..2492fc446d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,21 +8,24 @@ # TODO(SF-N): test scan operator -import pytest +from typing import Iterable, Literal, Optional, Union + import numpy as np -from typing import Iterable, Optional, Literal, Union +import pytest from gt4py import eve -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next import constructors +from gt4py.next import common, constructors, utils +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.ir_utils import domain_utils -from gt4py.next.common import Dimension -from gt4py.next import common -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next import utils +from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py deleted file mode 100644 index 407ccad924..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py +++ /dev/null @@ -1,68 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs - - -def test_simple(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected - - -def test_shadowing(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index b1ba4ccf22..03b8e3bc15 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -36,6 +36,7 @@ from . import pytestmark + dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview") From ae6296546d91f41e40451403c3560b1744d467cc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 21:15:55 +0100 Subject: [PATCH 076/178] feat[next]: Inline dynamic shifts (#1738) Dynamic shifts are not supported in the domain inference. In order to make them work nonetheless this PR aggressively inlines all arguments to `as_fieldop` until they contain only references to `itir.Program` params. Additionally the domain inference is extended to tolerate such `as_fieldop` by introducing a special domain marker that signifies a domain is unknown. --------- Co-authored-by: Hannes Vogt Co-authored-by: Edoardo Paone --- .../iterator/transforms/fuse_as_fieldop.py | 209 ++++++++------ .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/infer_domain.py | 272 +++++++++++------- .../transforms/inline_dynamic_shifts.py | 73 +++++ .../next/iterator/transforms/pass_manager.py | 7 + tests/next_tests/definitions.py | 1 - .../test_inline_dynamic_shifts.py | 48 ++++ .../transforms_tests/test_domain_inference.py | 115 +++++--- 8 files changed, 492 insertions(+), 237 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 9076bf2d3f..e8a221b814 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr): return isinstance(expr, itir.Literal) +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + + return new_node + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator - def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) - extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - @classmethod def apply( cls, @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] - domain = node.fun.args[1] if len(node.fun.args) > 1 else None - - shifts = trace_shifts.trace_stencil(stencil) - args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil) - new_args: dict[str, itir.Expr] = {} - new_stencil_body: itir.Expr = stencil.expr - - for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + eligible_args = [] + for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = _is_tuple_expr_of_literals(arg) or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") + eligible_args.append( + _is_tuple_expr_of_literals(arg) + or ( + isinstance(arg, itir.FunCall) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - if should_inline: - if cpm.is_applied_as_fieldop(arg): - pass - elif cpm.is_call_to(arg, "if_"): - # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type - arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ - elif _is_tuple_expr_of_literals(arg): - arg = im.op_as_fieldop(im.lambda_()(arg))() - else: - raise NotImplementedError() - - inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) - - new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) - - new_args = _merge_arguments(new_args, extracted_args) - else: - assert not isinstance(dtype, it_ts.ListType) - new_param: str - if isinstance( - arg, itir.SymRef - ): # use name from outer scope (optional, just to get a nice IR) - new_param = arg.id - new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) - else: - new_param = stencil_param.id - new_args = _merge_arguments(new_args, {new_param: arg}) - - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) - - # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node - ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True - ) - new_node = inline_lifts.InlineLifts().visit(new_node) - - type_inference.copy_type(from_=node, to=new_node) - return new_node + return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a6d39883e3..334fb330d7 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -74,7 +74,7 @@ def _transform_by_pattern( # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - domain = tmp_expr.annex.domain + domain: infer_domain.DomainAccess = tmp_expr.annex.domain # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are @@ -186,7 +186,7 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - program = infer_domain.infer_program(program, offset_provider) + program = infer_domain.infer_program(program, offset_provider=offset_provider) program = type_inference.infer( program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6852b47a7a..f26d3f9ec2 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,10 +10,10 @@ import itertools import typing -from typing import Callable, Optional, TypeAlias from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -25,8 +25,35 @@ from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAccessDescriptor(eve.StrEnum): + """ + Descriptor for domains that could not be inferred. + """ + + # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor + # `KNOWN`, but since we don't need it, it wasn't added. + + #: The access is unknown because of a dynamic shift.whose extent is not known. + #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` + UNKNOWN = "unknown" + #: The domain is never accessed. + #: E.g.: `{in_field1, in_field2}[0]` + NEVER = "never" + + +NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor +#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning +#: a tuple since other occurrences for tuples are removed before domain inference. This is +#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just +#: fine to a tuple of a vertex and an edge domain. +DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...] +AccessedDomains: TypeAlias = dict[str, DomainAccess] + + +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -57,43 +84,58 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. -def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... ) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -101,16 +143,16 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -118,44 +160,52 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} +) -> dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. + if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN + continue + new_domains = [ domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + allow_uninferred: bool, +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") + if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): - target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -177,22 +227,29 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes + in_field, + inputs_accessed_domains[in_field_id], + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) accessed_domains_without_tmp = { @@ -206,17 +263,15 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + input_domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes - ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} + + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) + accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -227,10 +282,9 @@ def _infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, - symbolic_domain_sizes, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -247,13 +301,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -261,13 +314,12 @@ def _infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -276,19 +328,18 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) - tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -297,18 +348,15 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -317,24 +365,23 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): - return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -347,10 +394,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + allow_uninferred: bool = False, +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -362,30 +411,35 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or never accessed. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. """ - # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) + expr, accessed_domains = _infer_expr( + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) expr.annex.domain = domain + return expr, accessed_domains def _infer_stmt( stmt: itir.Stmt, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], + **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): - transformed_call, _unused_domain = infer_expr( - stmt.expr, - domain_utils.SymbolicDomain.from_expr(stmt.domain), - offset_provider, - symbolic_domain_sizes, + transformed_call, _ = infer_expr( + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs ) + return itir.SetAt( expr=transformed_call, domain=stmt.domain, @@ -394,20 +448,18 @@ def _infer_stmt( elif isinstance(stmt, itir.IfStmt): return itir.IfStmt( cond=stmt.cond, - true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch - ], - false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch - ], + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], ) raise ValueError(f"Unsupported stmt: {stmt}") def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. @@ -423,5 +475,13 @@ def infer_program( function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + body=[ + _infer_stmt( + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..0af9d9dab9 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [ + any(trace_shifts.Sentinel.VALUE in shifts for shifts in param_shifts) + for param_shifts in params_shifts + ] + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := _dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec4207d726..d967c8fbb8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 @@ fuse_as_fieldop, global_tmps, infer_domain, + inline_dynamic_shifts, inline_fundefs, inline_lifts, ) @@ -73,6 +74,9 @@ def apply_common_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -158,5 +162,8 @@ def apply_fieldview_transforms( ir = CollapseTuple.apply( ir, offset_provider_type=common.offset_provider_to_type(offset_provider) ) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index d7413f32d7..bed6e89a52 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 2492fc446d..779ab738cb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -76,7 +76,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -89,12 +89,14 @@ def run_test_expr( expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, - symbolic_domain_sizes, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -104,10 +106,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -128,10 +128,12 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: domain_utils.SymbolicDomain | None): - if domain is None: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -154,7 +156,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( - shift_list, offset_provider + shift_list, offset_provider=offset_provider ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -340,7 +342,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -384,7 +386,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -397,7 +399,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -409,7 +414,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -522,7 +527,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -579,7 +584,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -803,7 +808,7 @@ def test_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -815,13 +820,13 @@ def test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -833,7 +838,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, @@ -841,7 +846,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -852,14 +857,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -870,12 +879,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -903,7 +916,7 @@ def test_nested_make_tuple(offset_provider): ), domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -914,10 +927,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -937,7 +950,7 @@ def test_domain_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -953,7 +966,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -973,7 +986,7 @@ def test_make_tuple_2tuple_get(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -990,7 +1003,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -1004,7 +1017,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) @@ -1048,3 +1061,35 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): unstructured_offset_provider, symbolic_domain_sizes, ) + + +def test_unknown_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) From 29b6af23c15955910f413ed12e5d1a463e7b5b4b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 9 Dec 2024 16:44:28 +0100 Subject: [PATCH 077/178] build: fix min version of filelock (#1777) ... and fix linting after ruff update. --- .pre-commit-config.yaml | 10 ++-- constraints.txt | 48 +++++++++---------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 48 +++++++++---------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/__init__.py | 4 +- src/gt4py/cartesian/backend/__init__.py | 2 +- src/gt4py/cartesian/cli.py | 2 +- src/gt4py/cartesian/frontend/__init__.py | 2 +- src/gt4py/cartesian/gtscript.py | 6 +-- src/gt4py/cartesian/testing/__init__.py | 2 +- src/gt4py/cartesian/utils/__init__.py | 2 +- src/gt4py/cartesian/utils/base.py | 6 +-- src/gt4py/eve/__init__.py | 2 +- src/gt4py/eve/datamodels/validators.py | 2 +- src/gt4py/next/errors/__init__.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/runtime.py | 2 +- .../next/iterator/transforms/__init__.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 6 ++- .../transformations/__init__.py | 14 +++--- src/gt4py/storage/__init__.py | 6 +-- 24 files changed, 88 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e1870c67f..e383112310 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.4 + rev: v0.8.2 ##[[[end]]] hooks: # Run the linter. @@ -96,7 +96,7 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.31.0.1 + - cmake==3.31.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 @@ -108,9 +108,9 @@ repos: - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 - - mako==1.3.6 - - nanobind==2.2.0 - - ninja==1.11.1.1 + - mako==1.3.8 + - nanobind==2.4.0 + - ninja==1.11.1.2 - numpy==1.24.4 - packaging==24.2 - pybind11==2.13.6 diff --git a/constraints.txt b/constraints.txt index f039fa2125..fbdfb6e267 100644 --- a/constraints.txt +++ b/constraints.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests -clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via gt4py (pyproject.toml) +cmake==3.31.1 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.8 # via ipykernel +debugpy==1.8.9 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via factory-boy -fastjsonschema==2.20.0 # via nbformat +faker==33.1.0 # via factory-boy +fastjsonschema==2.21.1 # via nbformat filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via matplotlib -fparser==0.1.4 # via dace +fonttools==4.55.2 # via matplotlib +fparser==0.2.0 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach @@ -75,7 +75,7 @@ jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat jupytext==1.16.4 # via -r requirements-dev.in kiwisolver==1.4.7 # via matplotlib lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.6 # via gt4py (pyproject.toml) +mako==1.3.8 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via jinja2, mako matplotlib==3.7.5 # via -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy mypy==1.13.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==2.2.0 # via gt4py (pyproject.toml) +nanobind==2.4.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.10.4 # via jupytext, nbclient, nbmake nbmake==1.5.4 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace, tach -ninja==1.11.1.1 # via gt4py (pyproject.toml) +ninja==1.11.1.2 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via ipython pickleshare==0.7.5 # via ipython pillow==10.4.0 # via matplotlib pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.23.4 # via -r requirements-dev.in +pipdeptree==2.24.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.10.0 # via bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via pydantic +pydantic==2.10.3 # via bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via pydantic pydantic-settings==2.6.1 # via bump-my-version -pydot==3.0.2 # via tach +pydot==3.0.3 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via matplotlib, pydot pyproject-api==1.8.0 # via tox pyproject-hooks==1.2.0 # via build, pip-tools -pytest==8.3.3 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==5.0.0 # via -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in @@ -137,12 +137,12 @@ questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.4 # via bump-my-version +rich-click==1.8.5 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.4 # via -r requirements-dev.in +ruff==0.8.2 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser -six==1.16.0 # via asttokens, astunparse, python-dateutil +six==1.17.0 # via asttokens, astunparse, python-dateutil smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis @@ -159,21 +159,21 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.4 # via -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tach==0.16.5 # via -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz -tornado==6.4.1 # via ipykernel, jupyter-client +tornado==6.4.2 # via ipykernel, jupyter-client tox==4.23.2 # via -r requirements-dev.in traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -r requirements-dev.in typing-extensions==4.12.2 # via annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via requests -virtualenv==20.27.1 # via pre-commit, tox +virtualenv==20.28.0 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -wheel==0.45.0 # via astunparse, pip-tools +wheel==0.45.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index d7679a1f0f..6d75415181 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,7 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index cf505e88d6..991b7a6941 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,7 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index e859c9b4f7..d086363ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', - 'filelock>=3.0.0', + 'filelock>=3.16.1', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 6542be36f1..40554cef13 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via -c constraints.txt, requests cfgv==3.4.0 # via -c constraints.txt, pre-commit chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests -clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.8 # via -c constraints.txt, ipykernel +debugpy==1.8.9 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via -c constraints.txt, factory-boy -fastjsonschema==2.20.0 # via -c constraints.txt, nbformat +faker==33.1.0 # via -c constraints.txt, factory-boy +fastjsonschema==2.21.1 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via -c constraints.txt, matplotlib -fparser==0.1.4 # via -c constraints.txt, dace +fonttools==4.55.2 # via -c constraints.txt, matplotlib +fparser==0.2.0 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach @@ -75,7 +75,7 @@ jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, n jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in kiwisolver==1.4.7 # via -c constraints.txt, matplotlib lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +mako==1.3.8 # via -c constraints.txt, gt4py (pyproject.toml) markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via -c constraints.txt, markdown-it-py mpmath==1.3.0 # via -c constraints.txt, sympy mypy==1.13.0 # via -c constraints.txt, -r requirements-dev.in mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.2.0 # via -c constraints.txt, gt4py (pyproject.toml) +nanobind==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) nbclient==0.6.8 # via -c constraints.txt, nbmake nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel, nbclient networkx==3.1 # via -c constraints.txt, dace, tach -ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +ninja==1.11.1.2 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via -c constraints.txt, ipython pickleshare==0.7.5 # via -c constraints.txt, ipython pillow==10.4.0 # via -c constraints.txt, matplotlib pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.23.4 # via -c constraints.txt, -r requirements-dev.in +pipdeptree==2.24.0 # via -c constraints.txt, -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via -c constraints.txt, pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via -c constraints.txt, pydantic +pydantic==2.10.3 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version -pydot==3.0.2 # via -c constraints.txt, tach +pydot==3.0.3 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot pyproject-api==1.8.0 # via -c constraints.txt, tox pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools -pytest==8.3.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in @@ -137,11 +137,11 @@ questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.4 # via -c constraints.txt, bump-my-version +rich-click==1.8.5 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.4 # via -c constraints.txt, -r requirements-dev.in +ruff==0.8.2 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser -six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil +six==1.17.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis @@ -158,21 +158,21 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tach==0.16.5 # via -c constraints.txt, -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz -tornado==6.4.1 # via -c constraints.txt, ipykernel, jupyter-client +tornado==6.4.2 # via -c constraints.txt, ipykernel, jupyter-client tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -c constraints.txt, -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -c constraints.txt, -r requirements-dev.in typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox +virtualenv==20.28.0 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.1 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 1b88285475..c0bf9580b3 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -27,6 +27,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 shadowing a Python builtin __all__ += ["next"] diff --git a/src/gt4py/cartesian/__init__.py b/src/gt4py/cartesian/__init__.py index c03ef15105..90df315d5c 100644 --- a/src/gt4py/cartesian/__init__.py +++ b/src/gt4py/cartesian/__init__.py @@ -27,7 +27,7 @@ __all__ = [ - "typing", + "StencilObject", "caching", "cli", "config", @@ -39,5 +39,5 @@ "stencil_builder", "stencil_object", "type_hints", - "StencilObject", + "typing", ] diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index e58c7a01a7..4296e3b389 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -32,9 +32,9 @@ "BasePyExtBackend", "CLIBackendMixin", "CudaBackend", - "GTGpuBackend", "GTCpuIfirstBackend", "GTCpuKfirstBackend", + "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 91daed9e98..4ea5e44074 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -90,7 +90,7 @@ def backend_table(cls) -> str: ", ".join(backend.languages["bindings"]) if backend and backend.languages else "?" for backend in backends ] - enabled = [backend is not None and "Yes" or "No" for backend in backends] + enabled = [(backend is not None and "Yes") or "No" for backend in backends] data = zip(names, comp_langs, binding_langs, enabled) return tabulate.tabulate(data, headers=headers) diff --git a/src/gt4py/cartesian/frontend/__init__.py b/src/gt4py/cartesian/frontend/__init__.py index 6988fb6aab..f1e0f9a775 100644 --- a/src/gt4py/cartesian/frontend/__init__.py +++ b/src/gt4py/cartesian/frontend/__init__.py @@ -10,4 +10,4 @@ from .base import REGISTRY, Frontend, from_name, register -__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"] +__all__ = ["REGISTRY", "Frontend", "from_name", "gtscript_frontend", "register"] diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 643ecba010..59f3ef37c2 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -657,10 +657,8 @@ def __str__(self) -> str: class _FieldDescriptorMaker: @staticmethod def _is_axes_spec(spec) -> bool: - return ( - isinstance(spec, Axis) - or isinstance(spec, collections.abc.Collection) - and all(isinstance(i, Axis) for i in spec) + return isinstance(spec, Axis) or ( + isinstance(spec, collections.abc.Collection) and all(isinstance(i, Axis) for i in spec) ) def __getitem__(self, field_spec): diff --git a/src/gt4py/cartesian/testing/__init__.py b/src/gt4py/cartesian/testing/__init__.py index 288d7b1d2d..0753b4175e 100644 --- a/src/gt4py/cartesian/testing/__init__.py +++ b/src/gt4py/cartesian/testing/__init__.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -__all__ = ["field", "global_name", "none", "parameter", "StencilTestSuite"] +__all__ = ["StencilTestSuite", "field", "global_name", "none", "parameter"] try: from .input_strategies import field, global_name, none, parameter from .suites import StencilTestSuite diff --git a/src/gt4py/cartesian/utils/__init__.py b/src/gt4py/cartesian/utils/__init__.py index 3c0bdb3fc3..626d29b167 100644 --- a/src/gt4py/cartesian/utils/__init__.py +++ b/src/gt4py/cartesian/utils/__init__.py @@ -37,7 +37,7 @@ ) -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # Modules "attrib", "meta", diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index d5d43a4103..35184a3f7b 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -63,10 +63,8 @@ def flatten_iter(nested_iterables, filter_none=False, *, skip_types=(str, bytes) def get_member(instance, item_name): try: - if ( - isinstance(instance, collections.abc.Mapping) - or isinstance(instance, collections.abc.Sequence) - and isinstance(item_name, int) + if isinstance(instance, collections.abc.Mapping) or ( + isinstance(instance, collections.abc.Sequence) and isinstance(item_name, int) ): return instance[item_name] else: diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 5adac47da3..e6044f15ef 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -71,7 +71,7 @@ from .visitors import NodeTranslator, NodeVisitor -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # version "__version__", "__version_info__", diff --git a/src/gt4py/eve/datamodels/validators.py b/src/gt4py/eve/datamodels/validators.py index 119410460c..4ce6f94c5e 100644 --- a/src/gt4py/eve/datamodels/validators.py +++ b/src/gt4py/eve/datamodels/validators.py @@ -42,7 +42,7 @@ from .core import DataModelTP, FieldValidator -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # reexported from attrs "and_", "deep_iterable", diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 89f78a45e4..9febe098a4 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -23,9 +23,9 @@ __all__ = [ "DSLError", "InvalidParameterAnnotationError", + "MissingArgumentError", "MissingAttributeError", "MissingParameterAnnotationError", - "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", ] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b60fa63f95..1210e96efc 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,7 +10,7 @@ import functools import inspect import math -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index e47a6886ad..c9a5b15de7 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] +__all__ = ["fendef", "fundef", "if_stmt", "offset", "set_at"] @dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index d0afc610e7..1d91254ee8 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -13,4 +13,4 @@ ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] +__all__ = ["GTIRTransform", "apply_common_transforms", "apply_fieldview_transforms"] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index e8a221b814..661b456608 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -240,8 +240,10 @@ def visit_FunCall(self, node: itir.FunCall): or ( isinstance(arg, itir.FunCall) and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) + ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + ) or cpm.is_call_to(arg, "if_") ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 2232bcef01..4f3efb19b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -43,25 +43,25 @@ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", "GT4PyGlobalSelfCopyElimination", - "GT4PyMoveTaskletIntoMap", "GT4PyMapBufferElimination", + "GT4PyMoveTaskletIntoMap", "LoopBlocking", - "MapIterationOrder", "MapFusionParallel", "MapFusionSerial", + "MapIterationOrder", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", "gt_change_transient_strides", "gt_create_local_double_buffering", + "gt_find_constant_arguments", + "gt_gpu_transform_non_standard_memlet", "gt_gpu_transformation", "gt_inline_nested_sdfg", - "gt_set_iteration_order", - "gt_set_gpu_blocksize", - "gt_simplify", "gt_make_transients_persistent", "gt_reduce_distributed_buffering", - "gt_find_constant_arguments", + "gt_set_gpu_blocksize", + "gt_set_iteration_order", + "gt_simplify", "gt_substitute_compiletime_symbols", - "gt_gpu_transform_non_standard_memlet", ] diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index 4866cd480c..5986baa65e 100644 --- a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -16,12 +16,12 @@ __all__ = [ "cartesian", - "layout", "empty", "from_array", + "from_name", "full", + "layout", "ones", - "zeros", - "from_name", "register", + "zeros", ] From 98889056c914886912d9131793deb67b5f947602 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 10 Dec 2024 10:02:22 +0100 Subject: [PATCH 078/178] feat[next]: Change interval syntax in ITIR pretty printer (#1766) We currently use `)` in the pretty printer to express an open interval. This is quite cumbersome when debugging the IR because it breaks matching parenthesis in the editor of functions and calls, e.g. when does a function start and end. This PR simply uses `[` instead. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +++--- src/gt4py/next/iterator/pretty_parser.py | 2 +- src/gt4py/next/iterator/pretty_printer.py | 4 +++- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 6 +++--- src/gt4py/next/iterator/transforms/inline_fundefs.py | 2 +- .../unit_tests/iterator_tests/test_pretty_parser.py | 4 ++-- .../unit_tests/iterator_tests/test_pretty_printer.py | 2 +- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a4e111e785..0839e95b5b 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -423,11 +423,11 @@ def domain( ... }, ... ) ... ) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)})) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)})) - 'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 29b30beae1..a077b39911 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 "[" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index a25f99356c..7acbf5d23d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -190,7 +190,9 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]: if fun_name == "named_range" and len(node.args) == 3: # named_range(dim, start, stop) → dim: [star, stop) dim, start, end = self.visit(node.args, prec=0) - res = self._hmerge(dim, [": ["], start, [", "], end, [")"]) + res = self._hmerge( + dim, [": ["], start, [", "], end, ["["] + ) # to get matching parenthesis of functions return self._prec_parens(res, prec, PRECEDENCE["__call__"]) if fun_name == "cartesian_domain" and len(node.args) >= 1: # cartesian_domain(x, y, ...) → c{ x × y × ... } # noqa: RUF003 [ambiguous-unicode-character-comment] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 661b456608..b7087472e0 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -186,15 +186,15 @@ class FuseAsFieldOp(eve.NodeTranslator): ... im.ref("inp3", field_type), ... ) >>> print(nested_as_fieldop) - as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( - as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2), inp3 ) >>> print( ... FuseAsFieldOp.apply( ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... ) - as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character uids: eve_utils.UIDGenerator diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index a2188030a1..e4cae978da 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -59,7 +59,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> print(prune_unreferenced_fundefs(program)) testee(inp, out) { fun1 = λ(a) → ·a; - out @ c⟨ IDimₕ: [0, 10) ⟩ ← fun1(inp); + out @ c⟨ IDimₕ: [0, 10[ ⟩ ← fun1(inp); } """ fun_names = [fun.id for fun in program.function_definitions] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index bf47f997d6..af9084f407 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -127,7 +127,7 @@ def test_make_tuple(): def test_named_range_horizontal(): - testee = "IDimₕ: [x, y)" + testee = "IDimₕ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], @@ -137,7 +137,7 @@ def test_named_range_horizontal(): def test_named_range_vertical(): - testee = "IDimᵥ: [x, y)" + testee = "IDimᵥ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 11f50dbf6d..6b45f470b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -233,7 +233,7 @@ def test_named_range_horizontal(): fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], ) - expected = "IDimₕ: [x, y)" + expected = "IDimₕ: [x, y[" actual = pformat(testee) assert actual == expected From 06b398af7c5a4235d2c595bbbac93ec70f31a5a6 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Dec 2024 15:32:16 +0100 Subject: [PATCH 079/178] refact[next][dace]: split handling of let-statement lambdas from stencil body (#1781) This is a refactoring of the code to lower lambda nodes: it splits the lowering of let-statements from the lowering of stencil expressions. --- .../gtir_builtin_translators.py | 43 ++--- .../runners/dace_fieldview/gtir_dataflow.py | 165 +++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 5 +- 3 files changed, 143 insertions(+), 70 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -30,7 +30,7 @@ gtir_python_codegen, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -39,7 +39,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -55,9 +55,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -96,7 +96,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -263,7 +263,7 @@ def _create_field_operator( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = sbs.Range.from_indices(field_indices) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -280,7 +280,7 @@ def _create_field_operator( field_dims.append(output_edge.result.gt_dtype.offset_type) field_shape.extend(dataflow_output_desc.shape) field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) @@ -366,36 +366,37 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." ) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -654,7 +655,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..a3653fb519 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,10 +10,22 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + Final, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common @@ -68,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -104,7 +116,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -117,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -152,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -202,7 +214,7 @@ def connect( self, mx: dace.nodes.MapExit, dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -256,10 +268,12 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ - Translates an `ir.Lambda` expression to a dataflow graph. + Visitor class to translate a `Lambda` expression to a dataflow graph. + This visitor should be applied by calling `apply()` method on a `Lambda` IR. The dataflow graph generated here typically represents the stencil function of a field operator. It only computes single elements or pure local fields, in case of neighbor values. In case of local fields, the dataflow contains @@ -275,25 +289,15 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -301,7 +305,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -512,7 +516,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -580,7 +584,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -596,7 +600,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -758,7 +762,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -908,7 +912,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1081,7 +1087,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1127,7 +1133,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1264,39 +1270,39 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + # Lambda node should be visited with 'visit_let()' method. + raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: + result: DataExpr = self.visit(node.expr) + + if isinstance(result, ValueExpr): + return DataflowOutputEdge(self.state, result) - if isinstance(output_expr, MemletExpr): + if isinstance(result, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + output_dtype = result.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, + result.dc_node, + result.subset, tasklet_node, "__inp", ) else: - assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + output_dtype = result.dc_dtype + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + return DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) @@ -1309,3 +1315,68 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def visit_let( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> DataflowOutputEdge: + """ + Maps lambda arguments to internal parameters. + + This method is responsible to recognize the usage of the `Lambda` node, + which can be either a let-statement or the stencil expression in local view. + The usage of a `Lambda` as let-statement corresponds to computing some results + and making them available inside the lambda scope, represented as a nested SDFG. + All let-statements, if any, are supposed to be encountered before the stencil + expression. In other words, the `Lambda` node representing the stencil expression + is always the innermost node. + Therefore, the lowering of let-statements results in recursive calls to + `visit_let()` until the stencil expression is found. At that point, it falls + back to the `visit()` function. + """ + + # lambda arguments are mapped to symbols defined in lambda scope. + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + return self.visit_let(let_node.fun, args=let_args) + else: + # this lambda node is not a let-statement, but a stencil expression + return self.visit(node) + + +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph, + that can be instantiated inside a map scope implementing the field operator. + + It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal + parameters and visit the let-statements (if any), which always appear as outermost + nodes. Finally, the visitor returns the output edge of the dataflow. + + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection. + """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..9bd40f75f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -602,7 +602,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -679,7 +679,7 @@ def get_field_domain_offset( self.offset_provider_type, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -690,6 +690,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, From 77cad7c8862c6164dff5f9e192ffef8fc9a2b1af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:53:40 +0100 Subject: [PATCH 080/178] feat[dace][next]: Fixing strides in optimization (#1782) Added functionality to properly handle changes of strides. During the implementation of the scan we found that the strides were not handled properly. Most importantly a change on one level was not propagated into the next levels, i.e. they were still using the old strides. This PR Solves most of the problems, but there are still some issues that are unsolved: - Views are not adjusted yet (Fixed in [PR@1784](https://github.com/GridTools/gt4py/pull/1784)). - It is not properly checked if the symbols of the propagated strides are safe to introduce into the nested SDFG. The initial functionality of this PR was done by Edoardo Paone (@edopao). --------- Co-authored-by: edopao --- .../transformations/__init__.py | 12 +- .../transformations/gpu_utils.py | 2 +- .../transformations/simplify.py | 5 +- .../dace_fieldview/transformations/strides.py | 611 +++++++++++++++++- .../test_map_buffer_elimination.py | 93 ++- .../transformation_tests/test_strides.py | 541 ++++++++++++++++ 6 files changed, 1238 insertions(+), 26 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..0902bd665a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,13 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +65,10 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", + "gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..4339a761fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -971,6 +971,9 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..980b2a8fdf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,14 +6,30 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional, TypeAlias + import dace from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the strides, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed, +however, we also need the inner array name because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + def gt_change_transient_strides( sdfg: dace.SDFG, gpu: bool, @@ -24,6 +40,11 @@ def gt_change_transient_strides( transients in the optimal way. The function should run after all maps have been created. + After the strides have been adjusted the function will also propagate + the strides into nested SDFG. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more. + Args: sdfg: The SDFG to process. gpu: If the SDFG is supposed to run on the GPU. @@ -35,8 +56,6 @@ def gt_change_transient_strides( Todo: - Implement the estimation correctly. - - Handle the case of nested SDFGs correctly; on the outside a transient, - but on the inside a non transient. """ # TODO(phimeull): Implement this function correctly. @@ -46,54 +65,608 @@ def gt_change_transient_strides( return sdfg for nsdfg in sdfg.all_sdfgs_recursive(): - # TODO(phimuell): Handle the case when transient goes into nested SDFG - # on the inside it is a non transient, so it is ignored. _gt_change_transient_strides_non_recursive_impl(nsdfg) def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order.""" - for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients, see `_gt_find_toplevel_data_accesses()` + and set their strides such that the access is optimal, see Note. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this lead to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout to GPU. + + Todo: + Make this function more intelligent to analyse the access pattern and then + figuring out the best order. + """ + # NOTE: Processing the transient here is enough. If we are inside a + # NestedSDFG then they were handled before on the level above us. + top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only make sense if we have more than one dimensions ndim = len(desc.shape) if ndim <= 1: continue + # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. + # TODO(phimuell): Improve this. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + # Now we have to propagate the changed strides. Because we already have + # collected all the AccessNodes we are using the + # `gt_propagate_strides_from_access_node()` function, but we have to + # create `processed_nsdfg` set already outside here. + # Furthermore, the same comment as above applies here, we do not have to + # propagate the non-transients, because they either come from outside, + # or they were already handled in the levels above, where they were + # defined and then propagated down. + # TODO(phimuell): Updated the functions such that only one scan is needed. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state, access_node in accesses: + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=access_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=True, + ) + + +def gt_propagate_strides_of( + sdfg: dace.SDFG, + data_name: str, + ignore_symbol_mapping: bool = True, +) -> None: + """Propagates the strides of `data_name` within the whole SDFG. + + This function will call `gt_propagate_strides_from_access_node()` for every + AccessNode that refers to `data_name`. It will also make sure that a descriptor + inside a NestedSDFG is only processed once. + + Args: + sdfg: The SDFG on which we operate. + data_name: Name of the data descriptor that should be handled. + ignore_symbol_mapping: If `False` (default is `True`) try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + """ + + # Defining it here ensures that we will not enter an NestedSDFG multiple times. + processed_nsdfgs: set[PropagatedStrideRecord] = set() + + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode.data != data_name: + continue + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_propagate_strides_from_access_node( + sdfg: dace.SDFG, + state: dace.SDFGState, + outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. + + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of data + descriptor within to match the strides on the outside. The function will then + recursively process NestedSDFG. + + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides the + `gt_propagate_strides_of()` should be used. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + """ + if processed_nsdfgs is None: + # For preventing the case that nested SDFGs are handled multiple time. + processed_nsdfgs = set() + + for in_edge in state.in_edges(outer_node): + gt_map_strides_to_src_nested_sdfg( + sdfg=sdfg, + state=state, + edge=in_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + for out_edge in state.out_edges(outer_node): + gt_map_strides_to_dst_nested_sdfg( + sdfg=sdfg, + state=state, + edge=out_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` in the dataflow direction. + + In this context "along the dataflow direction" means that `edge` is an outgoing + edge of `outer_node` and the strides are propagated into all NestedSDFGs that + are downstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when + you know what your are doing. + """ + assert edge.src is outer_node + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=True, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_map_strides_to_src_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow + + In this context "in the opposite direction of the dataflow" means that `edge` + is an incoming edge of `outer_node` and the strides are propagated into all + NestedSDFGs that are upstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when + you know what your are doing. + """ + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=False, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_to_nested_sdfg_src_dst( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]], + propagate_along_dataflow: bool, + ignore_symbol_mapping: bool = False, +) -> None: + """Propagates the stride of `outer_node` along `edge`. + + The function will follow `edge`, the direction depends on the value of + `propagate_along_dataflow` and propagate the strides of `outer_node` + into every NestedSDFG that is reachable by following `edge`. + + When the function encounters a NestedSDFG it will determine what data + the `outer_node` is mapped to on the inside of the NestedSDFG. + It will then replace the stride of the inner descriptor with the ones + of the outside. Afterwards it will recursively propagate the strides + inside the NestedSDFG. + During this propagation the function will follow any edges. + + If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` + then it will be skipped. NestedSDFGs that have been processed will be added + to the `processed_nsdfgs`. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + + Note: + A user should not use this function directly, instead `gt_propagate_strides_of()`, + `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == `False`) + or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == `True`) + should be used. + + Todo: + Try using `MemletTree` for the propagation. + """ + # If `processed_nsdfg` is `None` then this is the first call. We will now + # allocate the `set` and pass it as argument to all recursive calls, this + # ensures that the `set` is the same everywhere. + if processed_nsdfgs is None: + processed_nsdfgs = set() + + if propagate_along_dataflow: + # Propagate along the dataflow or forward, so we are interested at the `dst` of the edge. + ScopeNode = dace_nodes.MapEntry + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.dst + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.dst_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_src_subset(edge, state) -def _find_toplevel_transients( + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): + return [] + return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:])) + + else: + # Propagate against the dataflow or backward, so we are interested at the `src` of the edge. + ScopeNode = dace_nodes.MapExit + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.src + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.src_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_dst_subset(edge, state) + + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) + + if isinstance(get_node(edge), ScopeNode): + for next_edge in next_edges_by_connector(state, edge): + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=next_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=propagate_along_dataflow, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + elif isinstance(get_node(edge), dace.nodes.NestedSDFG): + nsdfg_node = get_node(edge) + inner_data = get_inner_data(edge) + process_record = (inner_data, nsdfg_node) + + if process_record in processed_nsdfgs: + # We already handled this NestedSDFG and the inner data. + return + + # Mark this nested SDFG as processed. + processed_nsdfgs.add(process_record) + + # Now set the stride of the data descriptor inside the nested SDFG to + # the ones it has outside. + _gt_map_strides_into_nested_sdfg( + sdfg=sdfg, + nsdfg_node=nsdfg_node, + inner_data=inner_data, + outer_subset=get_subset(state, edge), + outer_desc=outer_node.desc(sdfg), + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + # Since the function call above is not recursive we have now to propagate + # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()` + # is a bit overkill, but allows for a more uniform processing. + # TODO(phimuell): Instead of scanning every level for every data we modify + # we should scan the whole SDFG once and then reuse this information. + accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( + sdfg=nsdfg_node.sdfg, + only_transients=False, # Because on the nested levels they are globals. + only_arrays=True, + ) + for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): + # We have to use `gt_propagate_strides_from_access_node()` here because we + # have to handle its entirety. We could wait until the other branch processes + # the nested SDFG, but this might not work, so let's do it fully now. + gt_propagate_strides_from_access_node( + sdfg=nsdfg_node.sdfg, + state=nested_state, + outer_node=nested_access, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_into_nested_sdfg( sdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + outer_subset: dace.subsets.Subset, + outer_desc: dace_data.Data, + ignore_symbol_mapping: bool, +) -> None: + """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. + + `inner_data` is the name of a data descriptor inside the NestedSDFG. + The function will then modify the strides of `inner_data`, assuming this + is an array, to match the ones of `outer_desc`. + + Args: + sdfg: The SDFG containing the NestedSDFG. + nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. + inner_data: The name of the data descriptor that should be processed + inside the NestedSDFG (by construction also a connector name). + outer_subset: The subset that describes what part of the outer data is + mapped into the NestedSDFG. + outer_desc: The data descriptor of the data on the outside. + ignore_symbol_mapping: If possible the function will perform the renaming + through the `symbol_mapping` of the nested SDFG. If `True` then + the function will always perform the renaming. + Note that setting this value to `False` might have negative side effects. + + Todo: + - Handle explicit dimensions of size 1. + - What should we do if the stride symbol is used somewhere else, creating an + alias is probably not the right thing? + - Handle the case if the outer stride symbol is already used in another + context inside the Neste SDFG. + """ + # We need to compute the new strides. In the following we assume that the + # relative order of the dimensions does not change, but we support the case + # where some dimensions of the outer data descriptor are not present on the + # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We + # detect this case by checking if the Memlet subset in that dimension has size 1. + # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. + inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] + inner_shape = inner_desc.shape + inner_strides_init = inner_desc.strides + + outer_strides = outer_desc.strides + outer_inflow = outer_subset.size() + + new_strides: list = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + if dim_oinflow == 1: + # This is the case of implicit slicing along one dimension. + pass + else: + # There is inflow into the SDFG, so we need the stride. + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + + # If we have a scalar on the inside, then there is nothing to adjust. + # We could have performed the test above, but doing it here, gives us + # the chance of validating it. + if isinstance(inner_desc, dace_data.Scalar): + if len(new_strides) != 0: + raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") + return + + if not isinstance(inner_desc, dace_data.Array): + raise TypeError( + f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." + ) + + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + + # Now we actually replace the strides, there are two ways of doing it. + # The first is to create an alias in the `symbol_mapping`, however, + # this is only possible if the current strides are singular symbols, + # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` + # or literal values. Furthermore, this would change the meaning of the + # old stride symbol in any context and not only in the one of the stride + # of a single and isolated data descriptor. + # The second way would be to replace `strides` attribute of the + # inner data descriptor. In case the new stride consists of expressions + # such as `value1 - value2` we have to make them available inside the + # NestedSDFG. However, it could be that the strides is used somewhere else. + # We will do the following, if `ignore_symbol_mapping` is `False` and + # the strides of the inner descriptors are symbols, we will use the + # symbol mapping. Otherwise, we will replace the `strides` attribute + # of the inner descriptor, in addition we will install a remapping, + # for those values that were a symbol. + if (not ignore_symbol_mapping) and all( + isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init + ): + # Use the symbol + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + # We have to replace the `strides` attribute of the inner descriptor. + inner_desc.set_shape(inner_desc.shape, new_strides) + + # Now find the free symbols that the new strides need. + # Note that usually `free_symbols` returns `set[str]`, but here, because + # we fall back on SymPy, we get back symbols. We will keep them, because + # then we can use them to extract the type form them, which we need later. + new_strides_symbols: list[dace.symbol] = [] + for new_stride_dim in new_strides: + if dace.symbolic.issymbolic(new_stride_dim): + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + else: + # It is not already a symbol, so we turn it into a symbol. + # However, we only add it, if it is also a symbol, for example `1`. + # should not be added. + new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim) + if new_stride_symbol.is_symbol: + new_strides_symbols.append(new_stride_symbol) + + # Now we determine the set of symbols that should be mapped inside the NestedSDFG. + # We will exclude all that are already inside the `symbol_mapping` (we do not + # check if they map to the same value, we just hope it). Furthermore, + # we will exclude all symbols that are listed in the `symbols` property + # of the SDFG that is nested, and hope that it has the same meaning. + # TODO(phimuell): Add better checks to avoid overwriting. + missing_symbol_mappings: set[dace.symbol] = { + sym + for sym in new_strides_symbols + if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping) + } + + # Now propagate the symbols from the parent SDFG to the NestedSDFG. + for sym in missing_symbol_mappings: + assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG." + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) + nsdfg_node.symbol_mapping[sym.name] = sym + + +def _gt_find_toplevel_data_accesses( + sdfg: dace.SDFG, + only_transients: bool, only_arrays: bool = False, -) -> set[str]: - """Find all top level transients in the SDFG. +) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]: + """Find all data that is accessed on the top level. The function will scan the SDFG, ignoring nested one, and return the - name of all transients that have an access node at the top level. - However, it will ignore access nodes that refers to registers. + name of all data that only have AccessNodes on the top level. In data + is found that has an AccessNode on both the top level and in a nested + scope and error is generated. + By default the function will return transient and non transient data, + however, if `only_transients` is `True` then only transient data will + be returned. + Furthermore, the function will ignore an access in the following cases: + - The AccessNode refers to data that is a register. + - The AccessNode refers to a View. + + Args: + sdfg: The SDFG to process. + only_transients: If `True` only include transients. + only_arrays: If `True`, defaults to `False`, only arrays are returned. + + Returns: + A `dict` that maps the name of a data container, to a list of tuples + containing the state where the AccessNode was found and the AccessNode. """ - top_level_transients: set[str] = set() + # List of data that is accessed on the top level and all its access node. + top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() + + # List of all data that were found not on top level. + not_top_level_data: set[str] = set() + for state in sdfg.states(): scope_dict = state.scope_dict() for dnode in state.data_nodes(): data: str = dnode.data if scope_dict[dnode] is not None: - if data in top_level_transients: - top_level_transients.remove(data) + # The node was not found on the top level. So we can ignore it. + # We also check if it was ever found on the top level, this should + # not happen, as everything should go through Maps. But some strange + # DaCe transformation might do it. + assert ( + data not in top_level_data + ), f"Found {data} on the top level and inside a scope." + not_top_level_data.add(data) continue - elif data in top_level_transients: + + elif data in top_level_data: + # The data is already known to be in top level data, so we must add the + # AccessNode to the list of known nodes. But nothing else. + top_level_data[data].append((state, dnode)) continue + elif gtx_transformations.util.is_view(dnode, sdfg): + # The AccessNode refers to a View so we ignore it anyway. continue + + # We have found a new data node that is on the top node and is unknown. + assert ( + data not in not_top_level_data + ), f"Found {data} on the top level and inside a scope." desc: dace_data.Data = dnode.desc(sdfg) - if not desc.transient: + # Check if we only accept arrays + if only_arrays and not isinstance(desc, dace_data.Array): continue - elif only_arrays and not isinstance(desc, dace_data.Array): + + # For now we ignore registers. + # We do this because register are allocated on the stack, so the compiler + # has all information and should organize the best thing possible. + # TODO(phimuell): verify this. + elif desc.storage is dace.StorageType.Register: continue - top_level_transients.add(data) - return top_level_transients + + # We are only interested in transients + if only_transients and (not desc.transient): + continue + + # Now create the new entry in the list and record the AccessNode. + top_level_data[data] = [(state, dnode)] + return top_level_data diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index 1a4ce6d047..a98eac3c2c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -22,10 +22,6 @@ import dace -def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: - return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} - - def _make_test_sdfg( output_name: str = "G", input_name: str = "G", @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply(): validate_all=True, ) assert count == 0 + + +def test_map_buffer_elimination_with_nested_sdfgs(): + """ + After removing a transient connected to a nested SDFG node, ensure that the strides + are propagated to the arrays in nested SDFG. + """ + + stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] + + # top-level sdfg + sdfg = dace.SDFG(util.unique_name("map_buffer")) + inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) + out, out_desc = sdfg.add_array( + "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) + ) + tmp, _ = sdfg.add_temp_transient_like(out_desc) + state = sdfg.add_state() + tmp_node = state.add_access(tmp) + + nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) + out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) + tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) + state1 = nsdfg1.add_state() + tmp1_node = state1.add_access(tmp1) + + nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) + out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) + tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) + state2 = nsdfg2.add_state() + tmp2_node = state2.add_access(tmp2) + + state2.add_mapped_tasklet( + "broadcast2", + map_ranges={"__i": "0:10"}, + code="__oval = __ival + 1.0", + inputs={ + "__ival": dace.Memlet(f"{inp2}[__i]"), + }, + outputs={ + "__oval": dace.Memlet(f"{tmp2}[__i]"), + }, + output_nodes={tmp2_node}, + external_edges=True, + ) + state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc)) + + nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"}) + me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"}) + state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py new file mode 100644 index 0000000000..5b16e41bc3 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -0,0 +1,541 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace import symbolic as dace_symbolic +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_strides_propagation_level3_sdfg() -> dace.SDFG: + """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" + sdfg = dace.SDFG(util.unique_name("level3")) + state = sdfg.add_state(is_start_block=True) + names = ["a3", "c3"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL3", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a3[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("c3[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + """Generates the level 2 SDFG (nested) SDFG for `test_strides_propagation()`. + + The function returns the level 2 SDFG and the NestedSDFG node that contains + the level 3 SDFG. + """ + sdfg = dace.SDFG(util.unique_name("level2")) + state = sdfg.add_state(is_start_block=True) + names = ["a2", "a2_alias", "b2", "c2"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL2_1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "compL2_2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("c2[__i0]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet("a2_alias[__i0]")}, + external_edges=True, + ) + + # This is the nested SDFG we have here. + sdfg_level3 = _make_strides_propagation_level3_sdfg() + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level3, + parent=sdfg, + inputs={"a3"}, + outputs={"c3"}, + symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols}, + ) + + state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]")) + state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]")) + sdfg.validate() + + return sdfg, nsdfg + + +def _make_strides_propagation_level1_sdfg() -> ( + tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG] +): + """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`. + + Note that the SDFG is valid, but will be indeterminate. The only point of + this SDFG is to have a lot of different situations that have to be handled + for renaming. + + Returns: + A tuple of length three, with the following members: + - The top level SDFG. + - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG). + - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). + """ + + sdfg = dace.SDFG(util.unique_name("level1")) + state = sdfg.add_state(is_start_block=True) + names = ["a1", "b1", "c1"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg() + + nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg, + inputs={"a2", "c2"}, + outputs={"a2_alias", "b2", "c2"}, + symbol_mapping={s: s for s in sdfg_level2.free_symbols}, + ) + + for inner_name in nsdfg_level2.in_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + state.add_access(outer_name), + None, + nsdfg_level2, + inner_name, + dace.Memlet(f"{outer_name}[0:10]"), + ) + for inner_name in nsdfg_level2.out_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + nsdfg_level2, + inner_name, + state.add_access(outer_name), + None, + dace.Memlet(f"{outer_name}[0:10]"), + ) + + sdfg.validate() + + return sdfg, nsdfg_level2, nsdfg_level3 + + +def test_strides_propagation_use_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False) + sdfg_level1.validate() + + # Because `ignore_symbol_mapping=False` the strides of the data descriptor should + # not have changed. But the `symbol_mapping` has been updated for `a` and `b`. + # However, the symbols will only point one level above. + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + + if aname.startswith("c"): + target_symbol = f"{aname}_stride" + else: + target_symbol = f"{aname[0]}{level - 1}_stride" + + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) + sdfg_level1.validate() + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + target_symbol = f"{aname[0]}{level-1}_stride" + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + +def test_strides_propagation_ignore_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # After the propagation `a` and `b` should use the same stride (the one that + # it has on level 1, but `c` should still be level depending. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + original_stride = f"{aname}_stride" + if aname.startswith("c"): + exp_stride = f"{aname}_stride" + else: + exp_stride = f"{aname[0]}1_stride" + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride + + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) + sdfg_level1.validate() + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname[0]}1_stride" + original_stride = f"{aname}_stride" + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + # The symbol mapping must should not be updated. + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride + + +def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + array_names = ["a2", "b2"] + for name in array_names: + stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64) + sdfg.add_symbol(stride_sym.name, stride_sym.dtype) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + array_names = ["a1", "b1"] + for name in array_names: + stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64) + stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype) + stride_sym = stride_sym1 * stride_sym2 + sdfg_level1.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg() + + for sym, sym_dtype in sdfg_level2.symbols.items(): + sdfg_level1.add_symbol(sym, sym_dtype) + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_dependent_symbol(): + sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg() + sym1_dtype = dace.uint64 + sym2_dtype = dace.int64 + + # Ensure that the special symbols are not already present inside the nested SDFG. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in {fs.name for fs in adesc.strides[0].free_symbols} + assert sym not in nsdfg_level2.symbol_mapping + assert sym not in nsdfg_level2.sdfg.symbols + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + + # Now propagate `a1` and `b1`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # Now we check if the update has worked. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")] + assert adesc2.strides == adesc.strides + + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in nsdfg_level2.symbol_mapping + assert nsdfg_level2.symbol_mapping[sym].name == sym + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + assert sym in nsdfg_level2.sdfg.symbols + assert nsdfg_level2.sdfg.symbols[sym] == dtype + + +def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + # NOTE: Both arrays have the same symbols used for strides. + array_names = ["a2", "b2"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=(stride_sym0, stride_sym1), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in1": dace.Memlet("a2[__i0, __i1]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + # NOTE: Both arrays use the same symbols as strides. + # Furthermore, they are the same as in the nested SDFG, i.e. they are shared. + array_names = ["a1", "b1"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg_level1.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=( + stride_sym0, + stride_sym1, + ), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg() + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_shared_symbols_sdfg(): + """Tests what happens if symbols are (unintentionally) shred between descriptor. + + This test looks rather artificial, but it is actually quite likely. Because + transients will most likely have the same shape and if the strides are not + set explicitly, which is the case, the strides will also be related to their + shape. This test explores the situation, where we can, for whatever reason, + only propagate the strides of one such data descriptor. + + Note: + If `ignore_symbol_mapping` is `False` then this test will fail. + This is because the `symbol_mapping` of the NestedSDFG will act on the + whole SDFG. Thus it will not only change the strides of `b` but as an + unintended side effect also the strides of `a`. + """ + + def ref(a1, b1): + for i in range(10): + for j in range(10): + b1[i, j] = a1[i, j] + 10.0 + + sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg() + + res_args = { + "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + # Now we change the strides of `b1`, and then we propagate the new strides + # into the nested SDFG. We want to keep (for whatever reasons) strides of `a1`. + stride_b1_sym0 = dace.symbol(f"__b1_stride_0", dtype=dace.uint64) + stride_b1_sym1 = dace.symbol(f"__b1_stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype) + sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype) + + desc_b1 = sdfg_level1.arrays["b1"] + desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of( + sdfg=sdfg_level1, + data_name="b1", + ) + + # Now we have to prepare the call arguments, i.e. adding the strides + itemsize = res_args["b1"].itemsize + res_args.update( + { + "__b1_stride_0": res_args["b1"].strides[0] // itemsize, + "__b1_stride_1": res_args["b1"].strides[1] // itemsize, + "__stride_0": res_args["a1"].strides[0] // itemsize, + "__stride_1": res_args["a1"].strides[1] // itemsize, + } + ) + ref(**ref_args) + sdfg_level1(**res_args) + assert np.allclose(ref_args["b1"], res_args["b1"]) From e8743dd357656f25c2b73884858bddc56f72d0a0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 6 Jan 2025 10:38:26 +0100 Subject: [PATCH 081/178] ci: fix boost install in cartesian and daily ci plan (#1787) Boost download link expired, but actually no custom boost (header) installation is required. --- .github/workflows/daily-ci.yml | 7 ------- .github/workflows/test-cartesian.yml | 10 ++-------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 30ad0a6ff9..7ece5a4d5e 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -34,13 +34,6 @@ jobs: shell: bash run: | sudo apt install libboost-dev - wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz - echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt - sha256sum -c boost_hash.txt - tar xzf boost_1_76_0.tar.gz - mkdir -p boost/include - mv boost_1_76_0/boost boost/include/ - echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index aa59660a68..f7e78ee6c1 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -29,16 +29,10 @@ jobs: tox-factor: [internal, dace] steps: - uses: actions/checkout@v4 - - name: Install boost + - name: Install C++ libraries shell: bash run: | - wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz - echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt - sha256sum -c boost_hash.txt - tar xzf boost_1_76_0.tar.gz - mkdir -p boost/include - mv boost_1_76_0/boost boost/include/ - echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV + sudo apt install libboost-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: From 8040178d73d54ca0556b6ff7be5e4bce6b2d8ab9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 7 Jan 2025 11:32:35 +0100 Subject: [PATCH 082/178] bug[next]: Fix propagated symbols order stability across runs (#1788) If statements in FOAST have a special `propagated_symbols` annex attributed that contains all symbols that are defined and hence available outside of the if. The compuation of these symbols is based on a set operation whose order is not stable across runs. This PR fixes that by sorting them. Since behavior changes across runs are hard to test and we don't have any facilities for that no test is provided. --- src/gt4py/next/ffront/field_operator_ast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 4693fed1a0..4f547aae14 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -180,7 +180,7 @@ class IfStmt(Stmt): @datamodels.root_validator @classmethod def _collect_common_symbols(cls: type[IfStmt], instance: IfStmt) -> None: - common_symbol_names = ( + common_symbol_names = sorted( # sort is required to get stable results across runs instance.true_branch.annex.symtable.keys() & instance.false_branch.annex.symtable.keys() ) instance.annex.propagated_symbols = { From 1601d87d3ea2b24b340165f3f1b8c37edafcd52e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 10 Jan 2025 03:09:08 -0500 Subject: [PATCH 083/178] fix[cartesian]: Race condition in unit test for K write (#1791) The `interval` analysis in the unit test `test_K_offset_write_conditional` fails to catch a mistake in the code that leads to a race condition. Work: - Fix the bad interval - Remove not needed restriction on CUDA version Further work to fix the underlying problem and the larger issue of bound check on variable indexing is covered [here](https://github.com/GridTools/gt4py/issues/1684) and [there](https://github.com/GridTools/gt4py/issues/1754) --- .../cartesian/frontend/gtscript_frontend.py | 15 ----- .../test_code_generation.py | 60 ++++++++++++------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f155ea6209..4d8ac98529 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1451,21 +1451,6 @@ def visit_Assign(self, node: ast.Assign) -> list: message="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", loc=nodes.Location.from_ast_node(t), ) - if self.backend_name in ["gt:gpu", "dace:gpu"]: - import cupy as cp - - if cp.cuda.runtime.runtimeGetVersion() < 12000: - raise GTScriptSyntaxError( - message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} for CUDA<12. Please update CUDA.", - loc=nodes.Location.from_ast_node(t), - ) - - if self.backend_name in ["gt:gpu"]: - raise GTScriptSyntaxError( - message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." - "Please refer to https://github.com/GridTools/gt4py/issues/1754.", - loc=nodes.Location.from_ast_node(t), - ) if not self._is_known(name): if name in self.temp_decls: diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 4609184547..5a43144b4b 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,17 +582,6 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["dace:gpu"]: - import cupy as cp - - if cp.cuda.runtime.runtimeGetVersion() < 12000: - pytest.skip( - f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" - ) - if backend in ["gt:gpu"]: - pytest.skip( - f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" - ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -664,17 +653,6 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["dace:gpu"]: - import cupy as cp - - if cp.cuda.runtime.runtimeGetVersion() < 12000: - pytest.skip( - f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" - ) - if backend in ["gt:gpu"]: - pytest.skip( - f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" - ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -682,7 +660,7 @@ def test_K_offset_write_conditional(backend): @gtscript.stencil(backend=backend) def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): - with computation(BACKWARD), interval(1, None): + with computation(BACKWARD), interval(1, -1): if A > 0 and B > 0: A[0, 0, -1] = scalar B[0, 0, 1] = A @@ -700,6 +678,42 @@ def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scala backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 ) column_physics_conditional(A, B, 2.0) + # Manual unroll of the above + # Starts with + # - A[...] = [40, 41, 42, 43] + # - B[...] = [1, 1, 1, 1] + # Now in-stencil + # ITERATION k = 2 of [2:1] + # if condition + # - A[2] == 42 && B[2] == 1 => True + # - A[1] = 2.0 + # - B[3] = A[2] = 42 + # while + # - lev = 1 + # - A[2] == 42 && B[2] == 1 => True + # - A[3] = -1 + # - B[2] = -1 + # - lev = 2 + # - A[2] == 42 && B[2] == -1 => False + # End of iteration state + # - A[...] = A[40, 2.0, 2.0, -1] + # - B[...] = A[1, 1, -1, 42] + # ITERATION k = 1 of [2:1] + # if condition + # - A[1] == 2.0 && B[1] == 1 => True + # - A[0] = 2.0 + # - B[2] = A[1] = 2.0 + # while + # - lev = 1 + # - A[1] == 2.0 && B[1] == 1 => True + # - A[2] = -1 + # - B[1] = -1 + # - lev = 2 + # - A[1] == 2.0 && B[2] == -1 => False + # End of stencil state + # - A[...] = A[2.0, 2.0, -1, -1] + # - B[...] = A[1, -1, 2.0, 42] + assert (A[0, 0, :] == arraylib.array([2, 2, -1, -1])).all() assert (B[0, 0, :] == arraylib.array([1, -1, 2, 42])).all() From 3a1a40393071ffa4cef11f0d6e70683b2fd3ddf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 10 Jan 2025 09:19:40 +0100 Subject: [PATCH 084/178] build: drop support for python 3.8 and 3.9 (#1792) This PR only changes configuration files to drop support for old python versions. Refactorings and code changes will follow in different PRs. --- .github/workflows/_disabled/gt4py-sphinx.yml | 2 +- .github/workflows/daily-ci.yml | 2 +- .github/workflows/test-cartesian-fallback.yml | 2 +- .github/workflows/test-cartesian.yml | 2 +- .github/workflows/test-eve-fallback.yml | 2 +- .github/workflows/test-eve.yml | 2 +- .github/workflows/test-storage-fallback.yml | 2 +- .github/workflows/test-storage.yml | 2 +- .pre-commit-config.yaml | 28 ++-- README.md | 2 +- ci/cscs-ci.yml | 55 ------- constraints.txt | 150 +++++++++--------- min-extra-requirements-test.txt | 4 +- min-requirements-test.txt | 2 - pyproject.toml | 17 +- requirements-dev.txt | 149 +++++++++-------- src/gt4py/cartesian/gtc/common.py | 4 +- src/gt4py/eve/utils.py | 8 +- tox.ini | 33 ++-- 19 files changed, 194 insertions(+), 274 deletions(-) diff --git a/.github/workflows/_disabled/gt4py-sphinx.yml b/.github/workflows/_disabled/gt4py-sphinx.yml index d862ab7321..cb3b275787 100644 --- a/.github/workflows/_disabled/gt4py-sphinx.yml +++ b/.github/workflows/_disabled/gt4py-sphinx.yml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.10 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 7ece5a4d5e..28512a18ac 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -15,7 +15,7 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] tox-module-factor: ["cartesian", "eve", "next", "storage"] os: ["ubuntu-latest"] requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index 45bbdf271a..76fd898159 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] tox-factor: [internal, dace] steps: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index f7e78ee6c1..fd896c3d89 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -25,7 +25,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] tox-factor: [internal, dace] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 661118e71d..461400423f 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -18,7 +18,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index bfd6d8e481..e83c4c563b 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -22,7 +22,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] os: ["ubuntu-latest"] fail-fast: false diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index df861c6468..022c66b1f1 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -19,7 +19,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] tox-factor: [internal, dace] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 2f85670eeb..bfe6e49d23 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -23,7 +23,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] tox-factor: [internal, dace] os: ["ubuntu-latest"] fail-fast: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e383112310..051781ea49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.8.2 + rev: v0.8.6 ##[[[end]]] hooks: # Run the linter. @@ -72,9 +72,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.13.0 ========= + #========= FROM constraints.txt: v1.14.1 ========= ##[[[end]]] - rev: v1.13.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.14.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -90,31 +90,29 @@ repos: ## for pkg in packages: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - - astunparse==1.6.3 - - attrs==24.2.0 - - black==24.8.0 + - attrs==24.3.0 + - black==24.10.0 - boltons==24.1.0 - cached-property==2.0.1 - - click==8.1.7 - - cmake==3.31.1 - - cytoolz==1.0.0 - - deepdiff==8.0.1 + - click==8.1.8 + - cmake==3.31.2 + - cytoolz==1.0.1 + - deepdiff==8.1.1 - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 - filelock==3.16.1 - frozendict==2.4.6 - gridtools-cpp==2.3.8 - - importlib-resources==6.4.5 - - jinja2==3.1.4 + - jinja2==3.1.5 - lark==1.2.2 - mako==1.3.8 - nanobind==2.4.0 - - ninja==1.11.1.2 - - numpy==1.24.4 + - ninja==1.11.1.3 + - numpy==1.26.4 - packaging==24.2 - pybind11==2.13.6 - - setuptools==75.3.0 + - setuptools==75.8.0 - tabulate==0.9.0 - typing-extensions==4.12.2 - xxhash==3.0.0 diff --git a/README.md b/README.md index b782e20f63..07e0e1cdee 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. GT4Py is part of the GridTools framework, a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. -**NOTE:** The `gt4py.next` subpackage contains a new version of GT4Py which is not compatible with the current _stable_ version defined in `gt4py.cartesian`. The new version is highly experimental, it only works with unstructured meshes and it requires `python >= 3.10`. +**NOTE:** The `gt4py.next` subpackage contains a new version of GT4Py which is not compatible with the current _stable_ version defined in `gt4py.cartesian`. The new version is still experimental. ## 📃 Description diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7adb88459e..c2a872c1c4 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -9,13 +9,6 @@ include: PYVERSION_PREFIX: py310 PYVERSION: 3.10.9 -.py39: &py39 - PYVERSION_PREFIX: py39 - PYVERSION: 3.9.1 - -.py38: &py38 - PYVERSION_PREFIX: py38 - PYVERSION: 3.8.5 stages: - baseimage @@ -78,20 +71,6 @@ build_py310_baseimage_aarch64: variables: <<: *py310 -build_py39_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py39 -build_py39_baseimage_aarch64: - extends: .build_baseimage_aarch64 - variables: - <<: *py39 - -build_py38_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py38 - .build_image: stage: image @@ -128,23 +107,6 @@ build_py310_image_aarch64: variables: <<: *py310 -build_py39_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py39_baseimage_x86_64] - variables: - <<: *py39 -build_py39_image_aarch64: - extends: .build_image_aarch64 - needs: [build_py39_baseimage_aarch64] - variables: - <<: *py39 - -build_py38_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py38_baseimage_x86_64] - variables: - <<: *py38 - .test_helper: stage: test @@ -210,20 +172,3 @@ test_py310_aarch64: needs: [build_py310_image_aarch64] variables: <<: *py310 - -test_py39_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py39_image_x86_64] - variables: - <<: *py39 -test_py39_aarch64: - extends: [.test_helper_aarch64] - needs: [build_py39_image_aarch64] - variables: - <<: *py39 - -test_py38_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py38_image_x86_64] - variables: - <<: *py38 diff --git a/constraints.txt b/constraints.txt index fbdfb6e267..8b3e5e697f 100644 --- a/constraints.txt +++ b/constraints.txt @@ -1,182 +1,178 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # "tox run -e requirements-base" # aenum==3.1.15 # via dace -alabaster==0.7.13 # via sphinx +alabaster==1.0.0 # via sphinx annotated-types==0.7.0 # via pydantic asttokens==2.4.1 # via devtools, stack-data -astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) -attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, referencing +astunparse==1.6.3 # via dace +attrs==24.3.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.16.0 # via sphinx -backcall==0.2.0 # via ipython -black==24.8.0 # via gt4py (pyproject.toml) +black==24.10.0 # via gt4py (pyproject.toml) boltons==24.1.0 # via gt4py (pyproject.toml) bracex==2.5.post1 # via wcmatch build==1.2.2.post1 # via pip-tools -bump-my-version==0.28.1 # via -r requirements-dev.in +bump-my-version==0.29.0 # via -r requirements-dev.in cached-property==2.0.1 # via gt4py (pyproject.toml) cachetools==5.5.0 # via tox -certifi==2024.8.30 # via requests +certifi==2024.12.14 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.4.0 # via requests -clang-format==19.1.4 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.1 # via gt4py (pyproject.toml) +charset-normalizer==3.4.1 # via requests +clang-format==19.1.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +click==8.1.8 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click +cmake==3.31.2 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel -contourpy==1.1.1 # via matplotlib -coverage==7.6.1 # via -r requirements-dev.in, pytest-cov +contourpy==1.3.1 # via matplotlib +coverage==7.6.10 # via -r requirements-dev.in, pytest-cov cycler==0.12.1 # via matplotlib -cytoolz==1.0.0 # via gt4py (pyproject.toml) +cytoolz==1.0.1 # via gt4py (pyproject.toml) dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.9 # via ipykernel +debugpy==1.8.11 # via ipykernel decorator==5.1.1 # via ipython -deepdiff==8.0.1 # via gt4py (pyproject.toml) +deepdiff==8.1.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.9 # via dace diskcache==5.6.3 # via gt4py (pyproject.toml) distlib==0.3.9 # via virtualenv -docutils==0.20.1 # via sphinx, sphinx-rtd-theme -exceptiongroup==1.2.2 # via hypothesis, pytest +docutils==0.21.2 # via sphinx, sphinx-rtd-theme +exceptiongroup==1.2.2 # via hypothesis, ipython, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==33.1.0 # via factory-boy +faker==33.3.0 # via factory-boy fastjsonschema==2.21.1 # via nbformat filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.2 # via matplotlib +fonttools==4.55.3 # via matplotlib fparser==0.2.0 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) -gitdb==4.0.11 # via gitpython -gitpython==3.1.43 # via tach +gitdb==4.0.12 # via gitpython +gitpython==3.1.44 # via tach gridtools-cpp==2.3.8 # via gt4py (pyproject.toml) -hypothesis==6.113.0 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.1 # via pre-commit +hypothesis==6.123.11 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.6.5 # via pre-commit idna==3.10 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==8.5.0 # via build, jupyter-client, sphinx -importlib-resources==6.4.5 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest ipykernel==6.29.5 # via nbmake -ipython==8.12.3 # via ipykernel +ipython==8.31.0 # via ipykernel +jax==0.4.38 # via gt4py (pyproject.toml) +jaxlib==0.4.38 # via jax jedi==0.19.2 # via ipython -jinja2==3.1.4 # via gt4py (pyproject.toml), sphinx +jinja2==3.1.5 # via gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via nbformat -jsonschema-specifications==2023.12.1 # via jsonschema +jsonschema-specifications==2024.10.1 # via jsonschema jupyter-client==8.6.3 # via ipykernel, nbclient -jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat -jupytext==1.16.4 # via -r requirements-dev.in -kiwisolver==1.4.7 # via matplotlib +jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbclient, nbformat +jupytext==1.16.6 # via -r requirements-dev.in +kiwisolver==1.4.8 # via matplotlib lark==1.2.2 # via gt4py (pyproject.toml) mako==1.3.8 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich -markupsafe==2.1.5 # via jinja2, mako -matplotlib==3.7.5 # via -r requirements-dev.in +markupsafe==3.0.2 # via jinja2, mako +matplotlib==3.10.0 # via -r requirements-dev.in matplotlib-inline==0.1.7 # via ipykernel, ipython mdit-py-plugins==0.4.2 # via jupytext mdurl==0.1.2 # via markdown-it-py +ml-dtypes==0.5.1 # via jax, jaxlib mpmath==1.3.0 # via sympy -mypy==1.13.0 # via -r requirements-dev.in +mypy==1.14.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy nanobind==2.4.0 # via gt4py (pyproject.toml) -nbclient==0.6.8 # via nbmake +nbclient==0.10.2 # via nbmake nbformat==5.10.4 # via jupytext, nbclient, nbmake -nbmake==1.5.4 # via -r requirements-dev.in -nest-asyncio==1.6.0 # via ipykernel, nbclient -networkx==3.1 # via dace, tach -ninja==1.11.1.2 # via gt4py (pyproject.toml) +nbmake==1.5.5 # via -r requirements-dev.in +nest-asyncio==1.6.0 # via ipykernel +networkx==3.4.2 # via dace, tach +ninja==1.11.1.3 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit -numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy -orderly-set==5.2.2 # via deepdiff +numpy==1.26.4 # via contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, scipy +opt-einsum==3.4.0 # via jax +orderly-set==5.2.3 # via deepdiff packaging==24.2 # via black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via jedi pathspec==0.12.1 # via black pexpect==4.9.0 # via ipython -pickleshare==0.7.5 # via ipython -pillow==10.4.0 # via matplotlib +pillow==11.1.0 # via matplotlib pip-tools==7.4.1 # via -r requirements-dev.in pipdeptree==2.24.0 # via -r requirements-dev.in -pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.5.0 # via -r requirements-dev.in -prompt-toolkit==3.0.36 # via ipython, questionary, tach -psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist +pre-commit==4.0.1 # via -r requirements-dev.in +prompt-toolkit==3.0.48 # via ipython, questionary, tach +psutil==6.1.1 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.10.3 # via bump-my-version, pydantic-settings -pydantic-core==2.27.1 # via pydantic -pydantic-settings==2.6.1 # via bump-my-version -pydot==3.0.3 # via tach -pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.1.4 # via matplotlib, pydot +pydantic==2.10.4 # via bump-my-version, pydantic-settings +pydantic-core==2.27.2 # via pydantic +pydantic-settings==2.7.1 # via bump-my-version +pydot==3.0.4 # via tach +pygments==2.19.1 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx +pyparsing==3.2.1 # via matplotlib, pydot pyproject-api==1.8.0 # via tox pyproject-hooks==1.2.0 # via build, pip-tools pytest==8.3.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in -pytest-cov==5.0.0 # via -r requirements-dev.in +pytest-cov==6.0.0 # via -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in pytest-factoryboy==2.7.0 # via -r requirements-dev.in pytest-instafail==0.5.0 # via -r requirements-dev.in pytest-xdist==3.6.1 # via -r requirements-dev.in python-dateutil==2.9.0.post0 # via faker, jupyter-client, matplotlib python-dotenv==1.0.1 # via pydantic-settings -pytz==2024.2 # via babel pyyaml==6.0.2 # via dace, jupytext, pre-commit, tach pyzmq==26.2.0 # via ipykernel, jupyter-client -questionary==2.0.1 # via bump-my-version +questionary==2.1.0 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.5 # via bump-my-version -rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.8.2 # via -r requirements-dev.in -scipy==1.10.1 # via gt4py (pyproject.toml) +rpds-py==0.22.3 # via jsonschema, referencing +ruff==0.8.6 # via -r requirements-dev.in +scipy==1.15.0 # via gt4py (pyproject.toml), jax, jaxlib setuptools-scm==8.1.0 # via fparser six==1.17.0 # via asttokens, astunparse, python-dateutil -smmap==5.0.1 # via gitdb +smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis -sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx==8.1.3 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery sphinx-rtd-theme==3.0.2 # via -r requirements-dev.in -sphinxcontrib-applehelp==1.0.4 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 # via sphinx +sphinxcontrib-applehelp==2.0.0 # via sphinx +sphinxcontrib-devhelp==2.0.0 # via sphinx +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 # via sphinx +sphinxcontrib-qthelp==2.0.0 # via sphinx +sphinxcontrib-serializinghtml==2.0.0 # via sphinx stack-data==0.6.3 # via ipython -stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.16.5 # via -r requirements-dev.in -tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox -tomli-w==1.0.0 # via tach +tach==0.19.5 # via -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, sphinx, tach, tox +tomli-w==1.1.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz tornado==6.4.2 # via ipykernel, jupyter-client tox==4.23.2 # via -r requirements-dev.in traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat types-tabulate==0.9.0.20241207 # via -r requirements-dev.in -typing-extensions==4.12.2 # via annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox -urllib3==2.2.3 # via requests -virtualenv==20.28.0 # via pre-commit, tox +typing-extensions==4.12.2 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, tox +urllib3==2.3.0 # via requests +virtualenv==20.28.1 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit wheel==0.45.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.20.2 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==24.3.1 # via pip-tools, pipdeptree -setuptools==75.3.0 # via gt4py (pyproject.toml), pip-tools, setuptools-scm +setuptools==75.8.0 # via gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 6d75415181..a4924cc09c 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -49,7 +49,6 @@ ## result.append(str(make_min_req(r))) ## print("\n".join(sorted(result))) ##]]] -astunparse==1.6.3; python_version < "3.9" attrs==21.3 black==22.3 boltons==20.1 @@ -71,8 +70,7 @@ filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 -importlib-resources==5.0; python_version < "3.9" -jax[cpu]==0.4.18; python_version >= "3.10" +jax[cpu]==0.4.18 jinja2==3.0.0 jupytext==1.14 lark==1.1.2 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 991b7a6941..4b24385410 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -46,7 +46,6 @@ ## result.append(str(make_min_req(r))) ## print("\n".join(sorted(result))) ##]]] -astunparse==1.6.3; python_version < "3.9" attrs==21.3 black==22.3 boltons==20.1 @@ -67,7 +66,6 @@ filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 -importlib-resources==5.0; python_version < "3.9" jinja2==3.0.0 jupytext==1.14 lark==1.1.2 diff --git a/pyproject.toml b/pyproject.toml index d086363ec4..78735116ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,6 @@ classifiers = [ 'License :: OSI Approved :: BSD License', 'Operating System :: POSIX', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: Implementation :: CPython', @@ -24,7 +22,6 @@ classifiers = [ 'Topic :: Scientific/Engineering :: Physics' ] dependencies = [ - "astunparse>=1.6.3;python_version<'3.9'", 'attrs>=21.3', 'black>=22.3', 'boltons>=20.1', @@ -39,7 +36,6 @@ dependencies = [ 'filelock>=3.16.1', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', - "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', 'mako>=1.1', @@ -67,7 +63,7 @@ keywords = [ license = {file = 'LICENSE.txt'} name = 'gt4py' readme = 'README.md' -requires-python = '>=3.8' +requires-python = '>=3.10' [project.optional-dependencies] # Bundles @@ -80,9 +76,9 @@ cuda12 = ['cupy-cuda12x>=12.0'] dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 formatting = ['clang-format>=9.0'] gpu = ['cupy>=12.0'] -jax-cpu = ['jax[cpu]>=0.4.18; python_version>="3.10"'] -jax-cuda11 = ['jax[cuda11_pip]>=0.4.18; python_version>="3.10"'] -jax-cuda12 = ['jax[cuda12_pip]>=0.4.18; python_version>="3.10"'] +jax-cpu = ['jax[cpu]>=0.4.18'] +jax-cuda11 = ['jax[cuda11_pip]>=0.4.18'] +jax-cuda12 = ['jax[cuda12_pip]>=0.4.18'] performance = ['scipy>=1.9.2'] rocm-43 = ['cupy-rocm-4-3'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] @@ -275,7 +271,7 @@ line-length = 100 # It should be the same as in `tool.black.line-length` above respect-gitignore = true show-fixes = true # show-source = true -target-version = 'py38' +target-version = 'py310' [tool.ruff.format] docstring-code-format = true @@ -292,7 +288,8 @@ docstring-code-format = true # NPY: NumPy-specific rules # RUF: Ruff-specific rules ignore = [ - 'E501' # [line-too-long] + 'E501', # [line-too-long] + 'B905' # [zip-without-explicit-strict] # TODO(egparedes): Reevaluate this rule ] select = ['E', 'F', 'I', 'B', 'A', 'T10', 'ERA', 'NPY', 'RUF'] typing-modules = ['gt4py.eve.extended_typing'] diff --git a/requirements-dev.txt b/requirements-dev.txt index 40554cef13..463b1bc6ac 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,181 +1,178 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # "tox run -e requirements-base" # aenum==3.1.15 # via -c constraints.txt, dace -alabaster==0.7.13 # via -c constraints.txt, sphinx +alabaster==1.0.0 # via -c constraints.txt, sphinx annotated-types==0.7.0 # via -c constraints.txt, pydantic asttokens==2.4.1 # via -c constraints.txt, devtools, stack-data -astunparse==1.6.3 ; python_version < "3.9" # via -c constraints.txt, dace, gt4py (pyproject.toml) -attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypothesis, jsonschema, referencing +astunparse==1.6.3 # via -c constraints.txt, dace +attrs==24.3.0 # via -c constraints.txt, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.16.0 # via -c constraints.txt, sphinx -backcall==0.2.0 # via -c constraints.txt, ipython -black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) +black==24.10.0 # via -c constraints.txt, gt4py (pyproject.toml) boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml) bracex==2.5.post1 # via -c constraints.txt, wcmatch build==1.2.2.post1 # via -c constraints.txt, pip-tools -bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in +bump-my-version==0.29.0 # via -c constraints.txt, -r requirements-dev.in cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cachetools==5.5.0 # via -c constraints.txt, tox -certifi==2024.8.30 # via -c constraints.txt, requests +certifi==2024.12.14 # via -c constraints.txt, requests cfgv==3.4.0 # via -c constraints.txt, pre-commit chardet==5.2.0 # via -c constraints.txt, tox -charset-normalizer==3.4.0 # via -c constraints.txt, requests -clang-format==19.1.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.1 # via -c constraints.txt, gt4py (pyproject.toml) +charset-normalizer==3.4.1 # via -c constraints.txt, requests +clang-format==19.1.6 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +click==8.1.8 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click +cmake==3.31.2 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel -contourpy==1.1.1 # via -c constraints.txt, matplotlib -coverage[toml]==7.6.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov +contourpy==1.3.1 # via -c constraints.txt, matplotlib +coverage[toml]==7.6.10 # via -c constraints.txt, -r requirements-dev.in, pytest-cov cycler==0.12.1 # via -c constraints.txt, matplotlib -cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) +cytoolz==1.0.1 # via -c constraints.txt, gt4py (pyproject.toml) dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.9 # via -c constraints.txt, ipykernel +debugpy==1.8.11 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython -deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) +deepdiff==8.1.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) dill==0.3.9 # via -c constraints.txt, dace diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml) distlib==0.3.9 # via -c constraints.txt, virtualenv -docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme -exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest +docutils==0.21.2 # via -c constraints.txt, sphinx, sphinx-rtd-theme +exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, ipython, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==33.1.0 # via -c constraints.txt, factory-boy +faker==33.3.0 # via -c constraints.txt, factory-boy fastjsonschema==2.21.1 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.2 # via -c constraints.txt, matplotlib +fonttools==4.55.3 # via -c constraints.txt, matplotlib fparser==0.2.0 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) -gitdb==4.0.11 # via -c constraints.txt, gitpython -gitpython==3.1.43 # via -c constraints.txt, tach +gitdb==4.0.12 # via -c constraints.txt, gitpython +gitpython==3.1.44 # via -c constraints.txt, tach gridtools-cpp==2.3.8 # via -c constraints.txt, gt4py (pyproject.toml) -hypothesis==6.113.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.1 # via -c constraints.txt, pre-commit +hypothesis==6.123.11 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.6.5 # via -c constraints.txt, pre-commit idna==3.10 # via -c constraints.txt, requests imagesize==1.4.1 # via -c constraints.txt, sphinx -importlib-metadata==8.5.0 # via -c constraints.txt, build, jupyter-client, sphinx -importlib-resources==6.4.5 ; python_version < "3.9" # via -c constraints.txt, gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy iniconfig==2.0.0 # via -c constraints.txt, pytest ipykernel==6.29.5 # via -c constraints.txt, nbmake -ipython==8.12.3 # via -c constraints.txt, ipykernel +ipython==8.31.0 # via -c constraints.txt, ipykernel +jax[cpu]==0.4.38 # via -c constraints.txt, gt4py (pyproject.toml) +jaxlib==0.4.38 # via -c constraints.txt, jax jedi==0.19.2 # via -c constraints.txt, ipython -jinja2==3.1.4 # via -c constraints.txt, gt4py (pyproject.toml), sphinx +jinja2==3.1.5 # via -c constraints.txt, gt4py (pyproject.toml), sphinx jsonschema==4.23.0 # via -c constraints.txt, nbformat -jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema +jsonschema-specifications==2024.10.1 # via -c constraints.txt, jsonschema jupyter-client==8.6.3 # via -c constraints.txt, ipykernel, nbclient -jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, nbformat -jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in -kiwisolver==1.4.7 # via -c constraints.txt, matplotlib +jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, nbclient, nbformat +jupytext==1.16.6 # via -c constraints.txt, -r requirements-dev.in +kiwisolver==1.4.8 # via -c constraints.txt, matplotlib lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) mako==1.3.8 # via -c constraints.txt, gt4py (pyproject.toml) markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich -markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako -matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in +markupsafe==3.0.2 # via -c constraints.txt, jinja2, mako +matplotlib==3.10.0 # via -c constraints.txt, -r requirements-dev.in matplotlib-inline==0.1.7 # via -c constraints.txt, ipykernel, ipython mdit-py-plugins==0.4.2 # via -c constraints.txt, jupytext mdurl==0.1.2 # via -c constraints.txt, markdown-it-py +ml-dtypes==0.5.1 # via -c constraints.txt, jax, jaxlib mpmath==1.3.0 # via -c constraints.txt, sympy -mypy==1.13.0 # via -c constraints.txt, -r requirements-dev.in +mypy==1.14.1 # via -c constraints.txt, -r requirements-dev.in mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy nanobind==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) -nbclient==0.6.8 # via -c constraints.txt, nbmake +nbclient==0.10.2 # via -c constraints.txt, nbmake nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake -nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in -nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel, nbclient -networkx==3.1 # via -c constraints.txt, dace, tach -ninja==1.11.1.2 # via -c constraints.txt, gt4py (pyproject.toml) +nbmake==1.5.5 # via -c constraints.txt, -r requirements-dev.in +nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel +networkx==3.4.2 # via -c constraints.txt, dace, tach +ninja==1.11.1.3 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit -numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib -orderly-set==5.2.2 # via -c constraints.txt, deepdiff +numpy==1.26.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, scipy +opt-einsum==3.4.0 # via -c constraints.txt, jax +orderly-set==5.2.3 # via -c constraints.txt, deepdiff packaging==24.2 # via -c constraints.txt, black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox parso==0.8.4 # via -c constraints.txt, jedi pathspec==0.12.1 # via -c constraints.txt, black pexpect==4.9.0 # via -c constraints.txt, ipython -pickleshare==0.7.5 # via -c constraints.txt, ipython -pillow==10.4.0 # via -c constraints.txt, matplotlib +pillow==11.1.0 # via -c constraints.txt, matplotlib pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in pipdeptree==2.24.0 # via -c constraints.txt, -r requirements-dev.in -pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via -c constraints.txt, pytest, tox ply==3.11 # via -c constraints.txt, dace -pre-commit==3.5.0 # via -c constraints.txt, -r requirements-dev.in -prompt-toolkit==3.0.36 # via -c constraints.txt, ipython, questionary, tach -psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist +pre-commit==4.0.1 # via -c constraints.txt, -r requirements-dev.in +prompt-toolkit==3.0.48 # via -c constraints.txt, ipython, questionary, tach +psutil==6.1.1 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.10.3 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.27.1 # via -c constraints.txt, pydantic -pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version -pydot==3.0.3 # via -c constraints.txt, tach -pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot +pydantic==2.10.4 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.2 # via -c constraints.txt, pydantic +pydantic-settings==2.7.1 # via -c constraints.txt, bump-my-version +pydot==3.0.4 # via -c constraints.txt, tach +pygments==2.19.1 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx +pyparsing==3.2.1 # via -c constraints.txt, matplotlib, pydot pyproject-api==1.8.0 # via -c constraints.txt, tox pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools pytest==8.3.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in -pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in +pytest-cov==6.0.0 # via -c constraints.txt, -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in pytest-factoryboy==2.7.0 # via -c constraints.txt, -r requirements-dev.in pytest-instafail==0.5.0 # via -c constraints.txt, -r requirements-dev.in pytest-xdist[psutil]==3.6.1 # via -c constraints.txt, -r requirements-dev.in python-dateutil==2.9.0.post0 # via -c constraints.txt, faker, jupyter-client, matplotlib python-dotenv==1.0.1 # via -c constraints.txt, pydantic-settings -pytz==2024.2 # via -c constraints.txt, babel pyyaml==6.0.2 # via -c constraints.txt, dace, jupytext, pre-commit, tach pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client -questionary==2.0.1 # via -c constraints.txt, bump-my-version +questionary==2.1.0 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.5 # via -c constraints.txt, bump-my-version -rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.8.2 # via -c constraints.txt, -r requirements-dev.in +rpds-py==0.22.3 # via -c constraints.txt, jsonschema, referencing +ruff==0.8.6 # via -c constraints.txt, -r requirements-dev.in +scipy==1.15.0 # via -c constraints.txt, jax, jaxlib setuptools-scm==8.1.0 # via -c constraints.txt, fparser six==1.17.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil -smmap==5.0.1 # via -c constraints.txt, gitdb +smmap==5.0.2 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis -sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx==8.1.3 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery sphinx-rtd-theme==3.0.2 # via -c constraints.txt, -r requirements-dev.in -sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx -sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx -sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx +sphinxcontrib-applehelp==2.0.0 # via -c constraints.txt, sphinx +sphinxcontrib-devhelp==2.0.0 # via -c constraints.txt, sphinx +sphinxcontrib-htmlhelp==2.1.0 # via -c constraints.txt, sphinx sphinxcontrib-jquery==4.1 # via -c constraints.txt, sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via -c constraints.txt, sphinx -sphinxcontrib-qthelp==1.0.3 # via -c constraints.txt, sphinx -sphinxcontrib-serializinghtml==1.1.5 # via -c constraints.txt, sphinx +sphinxcontrib-qthelp==2.0.0 # via -c constraints.txt, sphinx +sphinxcontrib-serializinghtml==2.0.0 # via -c constraints.txt, sphinx stack-data==0.6.3 # via -c constraints.txt, ipython -stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.16.5 # via -c constraints.txt, -r requirements-dev.in -tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox -tomli-w==1.0.0 # via -c constraints.txt, tach +tach==0.19.5 # via -c constraints.txt, -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, sphinx, tach, tox +tomli-w==1.1.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz tornado==6.4.2 # via -c constraints.txt, ipykernel, jupyter-client tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat types-tabulate==0.9.0.20241207 # via -c constraints.txt, -r requirements-dev.in -typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox -urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.28.0 # via -c constraints.txt, pre-commit, tox +typing-extensions==4.12.2 # via -c constraints.txt, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, tox +urllib3==2.3.0 # via -c constraints.txt, requests +virtualenv==20.28.1 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit wheel==0.45.1 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: pip==24.3.1 # via -c constraints.txt, pip-tools, pipdeptree -setuptools==75.3.0 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm +setuptools==75.8.0 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index dcb01db7ca..8c3c731c75 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -229,8 +229,8 @@ class LevelMarker(eve.StrEnum): @enum.unique class ExprKind(eve.IntEnum): - SCALAR: ExprKind = typing.cast("ExprKind", enum.auto()) - FIELD: ExprKind = typing.cast("ExprKind", enum.auto()) + SCALAR = typing.cast("ExprKind", enum.auto()) + FIELD = typing.cast("ExprKind", enum.auto()) class LocNode(eve.Node): diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 2c66d39290..96e41a7bd8 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -440,8 +440,8 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No return result -ddiff = deepdiff.DeepDiff -"""Shortcut for deepdiff.DeepDiff. +ddiff = deepdiff.diff.DeepDiff +"""Shortcut for deepdiff.diff.DeepDiff. Check https://zepworks.com/deepdiff/current/diff.html for more info. """ @@ -458,13 +458,13 @@ def dhash(obj: Any, **kwargs: Any) -> str: def pprint_ddiff( old: Any, new: Any, *, pprint_opts: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: - """Pretty printing of deepdiff.DeepDiff objects. + """Pretty printing of deepdiff.diff.DeepDiff objects. Keyword Arguments: pprint_opts: kwargs dict with options for pprint.pprint. """ pprint_opts = pprint_opts or {"indent": 2} - pprint.pprint(deepdiff.DeepDiff(old, new, **kwargs), **pprint_opts) + pprint.pprint(deepdiff.diff.DeepDiff(old, new, **kwargs), **pprint_opts) AnyWordsIterable = Union[str, Iterable[str]] diff --git a/tox.ini b/tox.ini index 8da0e45810..e7bfd4a3e4 100644 --- a/tox.ini +++ b/tox.ini @@ -9,19 +9,14 @@ envlist = storage-py{310}-{internal,dace}-{cpu} # docs labels = - test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-internal-py39-cpu, \ - cartesian-internal-py310-cpu, cartesian-py311-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu - test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311 + test-cartesian-cpu = cartesian-internal-py310-cpu, cartesian-py311-internal-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu + test-eve-cpu = eve-py310, eve-py311 test-next-cpu = next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu - test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu - test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, cartesian-py311-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu, \ - eve-py38, eve-py39, eve-py310, eve-py311, \ + test-storage-cpu = storage-py310-internal-cpu, storage-py311-internal-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu + test-cpu = cartesian-py310-internal-cpu, cartesian-py311-internal-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu, \ + eve-py310, eve-py311, \ next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu, \ - storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu + storage-py310-internal-cpu, storage-py311-internal-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu [testenv] deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} @@ -42,7 +37,7 @@ set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning,ignore:Field View Program:UserWarning} # -- Primary tests -- -[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:cartesian-py{310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.cartesian' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE allowlist_externals = @@ -65,7 +60,7 @@ commands = # coverage json --rcfile=setup.cfg # coverage html --rcfile=setup.cfg --show-contexts -[testenv:eve-py{38,39,310,311}] +[testenv:eve-py{310,311}] description = Run 'gt4py.eve' tests commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests @@ -89,7 +84,7 @@ commands = " {posargs} tests{/}next_tests pytest --doctest-modules src{/}gt4py{/}next -[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:storage-py{310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.storage' tests commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ @@ -112,7 +107,7 @@ commands = python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} # -- Other artefacts -- -[testenv:dev-py{38,39,310,311}{-atlas,}] +[testenv:dev-py{310,311}{-atlas,}] description = Initialize development environment for gt4py deps = -r {tox_root}{/}requirements-dev.txt @@ -141,17 +136,13 @@ set_env = # git add _static # commands_post = -[testenv:requirements-{base,py38,py39,py310,py311}] +[testenv:requirements-{base,py310,py311}] description = base: Update pinned development requirements - py38: Update requirements for testing a specific python version - py39: Update requirements for testing a specific python version py310: Update requirements for testing a specific python version py311: Update requirements for testing a specific python version base_python = - base: py38 - py38: py38 - py39: py39 + base: py310 py310: py310 py311: py311 deps = From 22e4a89a1ee8563680396dbfe856ea0a1437be35 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 10 Jan 2025 09:45:20 +0100 Subject: [PATCH 085/178] refactor[next]: use eve.datamodel for types (#1750) --- src/gt4py/eve/datamodels/core.py | 6 +- src/gt4py/eve/type_validation.py | 9 +++ .../ffront/foast_passes/type_deduction.py | 28 +++++--- src/gt4py/next/ffront/foast_to_gtir.py | 3 + .../next/ffront/past_passes/type_deduction.py | 9 +++ src/gt4py/next/ffront/past_process_args.py | 1 + src/gt4py/next/ffront/type_info.py | 8 ++- src/gt4py/next/ffront/type_specifications.py | 8 +-- src/gt4py/next/iterator/embedded.py | 11 ++- .../iterator/transforms/fuse_as_fieldop.py | 9 +-- .../next/iterator/transforms/global_tmps.py | 12 ++-- .../next/iterator/type_system/inference.py | 6 +- .../type_system/type_specifications.py | 16 +---- .../iterator/type_system/type_synthesizer.py | 26 +++---- src/gt4py/next/otf/binding/nanobind.py | 1 + .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../gtir_builtin_translators.py | 6 +- .../runners/dace_fieldview/gtir_dataflow.py | 29 ++++---- .../runners/dace_fieldview/gtir_sdfg.py | 1 + .../runners/dace_fieldview/utility.py | 7 +- src/gt4py/next/type_system/type_info.py | 43 +++++++----- .../next/type_system/type_specifications.py | 67 +++++++++---------- .../next/type_system/type_translation.py | 8 +-- tests/eve_tests/unit_tests/test_datamodels.py | 15 ++++- .../iterator_tests/test_type_inference.py | 10 +-- 25 files changed, 186 insertions(+), 155 deletions(-) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 1b0e995156..31e63bdf9f 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -16,6 +16,7 @@ import dataclasses import functools import sys +import types import typing import warnings @@ -1254,8 +1255,11 @@ def _make_concrete_with_cache( if not is_generic_datamodel_class(datamodel_cls): raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: + _accepted_types: tuple[type, ...] = (type, type(None), xtyping.StdGenericAliasType) + if sys.version_info >= (3, 10): + _accepted_types = (*_accepted_types, types.UnionType) if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType)) + isinstance(t, _accepted_types) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index e150832295..695ab69dc3 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,6 +14,8 @@ import collections.abc import dataclasses import functools +import sys +import types import typing from . import exceptions, extended_typing as xtyping, utils @@ -193,6 +195,12 @@ def __call__( if type_annotation is None: type_annotation = type(None) + if sys.version_info >= (3, 10): + if isinstance( + type_annotation, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_annotation = typing.Union[type_annotation.__args__] + # Non-generic types if xtyping.is_actual_type(type_annotation): assert not xtyping.get_args(type_annotation) @@ -277,6 +285,7 @@ def __call__( if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)): assert len(type_args) == 1 + make_recursive(type_args[0]) if (member_validator := make_recursive(type_args[0])) is None: raise exceptions.EveValueError( f"{type_args[0]} type annotation is not supported." diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d334487ae1..6b40cbb77f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -48,7 +48,7 @@ def with_altered_scalar_kind( if isinstance(type_spec, ts.FieldType): return ts.FieldType( dims=type_spec.dims, - dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape), + dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind), ) elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) @@ -68,13 +68,18 @@ def construct_tuple_type( >>> mask_type = ts.FieldType( ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) ... ) - >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)] + >>> true_branch_types = [ + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ] >>> false_branch_types = [ - ... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), - ... ts.ScalarType(kind=ts.ScalarKind), + ... ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) - [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] + [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ element_types_new = true_branch_types for i, element in enumerate(true_branch_types): @@ -105,8 +110,8 @@ def promote_to_mask_type( >>> I, J = (Dimension(value=dim) for dim in ["I", "J"]) >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) + >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) ... ) @@ -360,7 +365,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign: def visit_TupleTargetAssign( self, node: foast.TupleTargetAssign, **kwargs: Any ) -> foast.TupleTargetAssign: - TargetType = list[foast.Starred | foast.Symbol] + TargetType: TypeAlias = list[foast.Starred | foast.Symbol] values = self.visit(node.value, **kwargs) if isinstance(values.type, ts.TupleType): @@ -374,7 +379,7 @@ def visit_TupleTargetAssign( ) new_targets: TargetType = [] - new_type: ts.TupleType | ts.DataType + new_type: ts.DataType for i, index in enumerate(indices): old_target = targets[i] @@ -391,7 +396,8 @@ def visit_TupleTargetAssign( location=old_target.location, ) else: - new_type = values.type.types[index] + new_type = values.type.types[index] # type: ignore[assignment] # see check in next line + assert isinstance(new_type, ts.DataType) new_target = self.visit( old_target, refine_type=new_type, location=old_target.location, **kwargs ) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3c65695aec..4519b4e571 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -236,6 +236,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") @@ -417,12 +418,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) min_value, _ = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(min_value), dtype) return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) _, max_value = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(max_value), dtype) return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 92f7327218..9355273588 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -104,6 +104,15 @@ def visit_Program(self, node: past.Program, **kwargs: Any) -> past.Program: location=node.location, ) + def visit_Slice(self, node: past.Slice, **kwargs: Any) -> past.Slice: + return past.Slice( + lower=self.visit(node.lower, **kwargs), + upper=self.visit(node.upper, **kwargs), + step=self.visit(node.step, **kwargs), + type=ts.DeferredType(constraint=None), + location=node.location, + ) + def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript: value = self.visit(node.value, **kwargs) return past.Subscript( diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 7958b7a8d3..1add668791 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -109,6 +109,7 @@ def _field_constituents_shape_and_dims( match arg_type: case ts.TupleType(): for el, el_type in zip(arg, arg_type.types): + assert isinstance(el_type, ts.DataType) yield from _field_constituents_shape_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8160a2c42d..83ecf92839 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -169,7 +169,9 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType -------- >>> _scan_param_promotion( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), + ... ts.FieldType( + ... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), ... ) FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ @@ -252,8 +254,8 @@ def function_signature_incompatibilities_scanop( # build a function type to leverage the already existing signature checking capabilities function_type = ts.FunctionType( pos_only_args=[], - pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here. - kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above + pos_or_kw_args=promoted_params, + kw_only_args=promoted_kwparams, returns=ts.DeferredType(constraint=None), ) diff --git a/src/gt4py/next/ffront/type_specifications.py b/src/gt4py/next/ffront/type_specifications.py index e4f6c826fe..b76a116297 100644 --- a/src/gt4py/next/ffront/type_specifications.py +++ b/src/gt4py/next/ffront/type_specifications.py @@ -6,23 +6,19 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common as func_common +from gt4py.next import common -@dataclass(frozen=True) class ProgramType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class FieldOperatorType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class ScanOperatorType(ts.TypeSpec, ts.CallableType): - axis: func_common.Dimension + axis: common.Dimension definition: ts.FunctionType diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 13c64e264e..5949d29432 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,7 +54,6 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -1460,7 +1459,7 @@ class _List(Generic[DT]): def __getitem__(self, i: int): return self.values[i] - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: offset_tag = self.offset.value assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) @@ -1470,7 +1469,7 @@ def __gt_type__(self) -> itir_ts.ListType: connectivity = offset_provider[offset_tag] assert common.is_neighbor_connectivity(connectivity) local_dim = connectivity.__gt_type__().neighbor_dim - return itir_ts.ListType(element_type=element_type, offset_type=local_dim) + return ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1480,10 +1479,10 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: element_type = type_translation.from_value(self.value) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( + return ts.ListType( element_type=element_type, offset_type=_CONST_DIM, ) @@ -1801,7 +1800,7 @@ def _fieldspec_list_to_value( domain: common.Domain, type_: ts.TypeSpec ) -> tuple[common.Domain, ts.TypeSpec]: """Translate the list element type into the domain.""" - if isinstance(type_, itir_ts.ListType): + if isinstance(type_, ts.ListType): if type_.offset_type == _CONST_DIM: return domain.insert( len(domain), common.named_range((_CONST_DIM, 1)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index b7087472e0..cc42896f2b 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -20,10 +20,7 @@ inline_lifts, trace_shifts, ) -from gt4py.next.iterator.type_system import ( - inference as type_inference, - type_specifications as it_ts, -) +from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -140,7 +137,7 @@ def fuse_as_fieldop( if arg.type and not isinstance(arg.type, ts.DeferredType): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - assert not isinstance(dtype, it_ts.ListType) + assert not isinstance(dtype, ts.ListType) new_param: str if isinstance( arg, itir.SymRef @@ -246,7 +243,7 @@ def visit_FunCall(self, node: itir.FunCall): ) or cpm.is_call_to(arg, "if_") ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) + and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 334fb330d7..ac7fcb8f1c 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -98,12 +98,12 @@ def _transform_by_pattern( tmp_expr.type, tuple_constructor=lambda *elements: tuple(elements), ) - tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( - type_info.apply_to_primitive_constituents( - type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) + tmp_dtypes: ( + ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] + ) = type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), ) # allocate temporary for all tuple elements diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1b980783fa..1da59546c0 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -275,8 +275,8 @@ def _get_dimensions(obj: Any): if isinstance(obj, common.Dimension): yield obj elif isinstance(obj, ts.TypeSpec): - for field in dataclasses.fields(obj.__class__): - yield from _get_dimensions(getattr(obj, field.name)) + for field in obj.__datamodel_fields__.keys(): + yield from _get_dimensions(getattr(obj, field)) elif isinstance(obj, collections.abc.Mapping): for el in obj.values(): yield from _get_dimensions(el) @@ -479,7 +479,7 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype, ) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index eef8c75d0f..7825bf1c98 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -6,43 +6,29 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses -from typing import Literal, Optional +from typing import Literal from gt4py.next import common from gt4py.next.type_system import type_specifications as ts -@dataclasses.dataclass(frozen=True) class NamedRangeType(ts.TypeSpec): dim: common.Dimension -@dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): dims: list[common.Dimension] | Literal["unknown"] -@dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension -@dataclasses.dataclass(frozen=True) -class ListType(ts.DataType): - element_type: ts.DataType - # TODO(havogt): the `offset_type` is not yet used in type_inference, - # it is meant to describe the neighborhood (via the local dimension) - offset_type: Optional[common.Dimension] = None - - -@dataclasses.dataclass(frozen=True) class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType -@dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..22a04ec04a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -155,18 +155,18 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType @_register_builtin_type_synthesizer -def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: +def make_const_list(scalar: ts.ScalarType) -> ts.ListType: assert isinstance(scalar, ts.ScalarType) - return it_ts.ListType(element_type=scalar) + return ts.ListType(element_type=scalar) @_register_builtin_type_synthesizer -def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: +def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): assert isinstance(index.value, ts.ScalarType) index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) - assert isinstance(list_, it_ts.ListType) + assert isinstance(list_, ts.ListType) return list_.element_type @@ -198,14 +198,14 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) and offset_literal.value.kind == common.DimensionKind.LOCAL ) assert isinstance(it, it_ts.IteratorType) - return it_ts.ListType(element_type=it.element_type) + return ts.ListType(element_type=it.element_type) @_register_builtin_type_synthesizer @@ -270,7 +270,7 @@ def _convert_as_fieldop_input_to_iterator( else: defined_dims.append(dim) if is_nb_field: - element_type = it_ts.ListType(element_type=element_type) + element_type = ts.ListType(element_type=element_type) return it_ts.IteratorType( position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type @@ -342,14 +342,14 @@ def apply_scan( def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType - ) -> it_ts.ListType: + *args: ts.ListType, offset_provider_type: common.OffsetProviderType + ) -> ts.ListType: assert len(args) > 0 - assert all(isinstance(arg, it_ts.ListType) for arg in args) + assert all(isinstance(arg, ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) - return it_ts.ListType(element_type=el_type) + return ts.ListType(element_type=el_type) return applied_map @@ -357,8 +357,8 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): - assert all(isinstance(arg, it_ts.ListType) for arg in args) + def applied_reduce(*args: ts.ListType, offset_provider_type: common.OffsetProviderType): + assert all(isinstance(arg, ts.ListType) for arg in args) return op( init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 24913a1365..edd56fad48 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) + assert isinstance(type_.dtype, ts.ScalarType) dtype = cpp_interface.render_scalar_type(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index d5b34fd5b9..f7bb1805e0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -701,7 +701,7 @@ def visit_Temporary( def dtype_to_cpp(x: ts.DataType) -> str: if isinstance(x, ts.TupleType): assert all(isinstance(i, ts.ScalarType) for i in x.types) - return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" + return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) res = pytype_to_cpptype(x) assert isinstance(res, str) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index cffbd74c90..354a9692d8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -23,7 +23,6 @@ domain_utils, ir_makers as im, ) -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, @@ -119,7 +118,7 @@ def get_local_view( ) elif len(local_dims) == 1: - field_dtype = itir_ts.ListType( + field_dtype = ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) field_domain = [ @@ -267,10 +266,11 @@ def _create_field_operator( if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert isinstance(node_type.dtype, ts.ScalarType) assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) field_dtype = output_edge.result.gt_dtype else: - assert isinstance(node_type.dtype, itir_ts.ListType) + assert isinstance(node_type.dtype, ts.ListType) assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type assert isinstance(dataflow_output_desc, dace.data.Array) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index a3653fb519..0376143883 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -31,7 +31,6 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, @@ -64,7 +63,7 @@ class ValueExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType @dataclasses.dataclass(frozen=True) @@ -79,7 +78,7 @@ class MemletExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType subset: dace_subsets.Range @@ -112,7 +111,7 @@ class IteratorExpr: """ field: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] @@ -121,7 +120,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: raise ValueError(f"Cannot deref iterator {self}.") field_desc = self.field.desc(sdfg) - if isinstance(self.gt_dtype, itir_ts.ListType): + if isinstance(self.gt_dtype, ts.ListType): assert len(field_desc.shape) == len(self.field_domain) + 1 assert self.gt_dtype.offset_type is not None field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] @@ -444,7 +443,7 @@ def _construct_tasklet_result( return ValueExpr( dc_node=temp_node, gt_dtype=( - itir_ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + ts.ListType(element_type=data_type, offset_type=_CONST_DIM) if use_array else data_type ), @@ -547,7 +546,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert len(node.args) == 2 assert isinstance(node.args[0], gtir.OffsetLiteral) @@ -650,7 +649,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) return ValueExpr( - dc_node=neighbors_node, gt_dtype=itir_ts.ListType(node.type.element_type, offset_type) + dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type) ) def _visit_map(self, node: gtir.FunCall) -> ValueExpr: @@ -669,7 +668,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: In above example, the result would be an array with size V2E.max_neighbors, containing the V2E neighbor values incremented by 1.0. """ - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 1 # the operation to be mapped on the arguments @@ -689,7 +688,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gtx_common.Dimension, gtx_common.NeighborConnectivityType ] = {} for input_arg in input_args: - assert isinstance(input_arg.gt_dtype, itir_ts.ListType) + assert isinstance(input_arg.gt_dtype, ts.ListType) assert input_arg.gt_dtype.offset_type is not None offset_type = input_arg.gt_dtype.offset_type if offset_type == _CONST_DIM: @@ -759,7 +758,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: connectivity_slice = self._construct_local_view( MemletExpr( dc_node=self.state.add_access(connectivity), - gt_dtype=itir_ts.ListType( + gt_dtype=ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), subset=dace_subsets.Range.from_string( @@ -798,7 +797,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: return ValueExpr( dc_node=result_node, - gt_dtype=itir_ts.ListType(node.type.element_type, offset_type), + gt_dtype=ts.ListType(node.type.element_type, offset_type), ) def _make_reduce_with_skip_values( @@ -825,7 +824,7 @@ def _make_reduce_with_skip_values( origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -938,7 +937,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: input_expr = self.visit(node.args[0]) assert isinstance(input_expr, (MemletExpr, ValueExpr)) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -1232,7 +1231,7 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: connector, ) - if isinstance(node.type, itir_ts.ListType): + if isinstance(node.type, ts.ListType): # The only builtin function (so far) handled here that returns a list # is 'make_const_list'. There are other builtin functions (map_, neighbors) # that return a list but they are handled in specialized visit methods. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 9bd40f75f8..10895ce66e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -290,6 +290,7 @@ def _add_storage( # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions + assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) if tuple_name is None: # Use symbolic shape, which allows to invoke the program with fields of different size; diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 118f0449c8..c46420c24b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -45,17 +45,18 @@ def get_tuple_fields( ... ("a_1_1", sty), ... ] """ + assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] if flatten: - expanded_fields = [ + expanded_fields: list[list[tuple[str, ts.DataType]]] = [ get_tuple_fields(field_name, field_type) if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] + else [(field_name, field_type)] # type: ignore[list-item] # checked in assert for field_name, field_type in fields ] return list(itertools.chain(*expanded_fields)) else: - return fields + return fields # type: ignore[return-value] # checked in assert def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 66f8937dc5..983063a9cb 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -78,15 +78,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: >>> type_class(ts.TupleType(types=[])).__name__ 'TupleType' """ - match symbol_type: - case ts.DeferredType(constraint): - if constraint is None: - raise ValueError(f"No type information available for '{symbol_type}'.") - elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") - return constraint - case ts.TypeSpec() as concrete_type: - return concrete_type.__class__ + if isinstance(symbol_type, ts.DeferredType): + constraint = symbol_type.constraint + if constraint is None: + raise ValueError(f"No type information available for '{symbol_type}'.") + elif isinstance(constraint, tuple): + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") + return constraint + if isinstance(symbol_type, ts.TypeSpec): + return symbol_type.__class__ raise ValueError( f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." ) @@ -197,7 +197,7 @@ def apply_to_primitive_constituents( return fun(*symbol_types) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: """ Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. @@ -234,7 +234,10 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ + ts.ScalarKind.FLOAT32, + ts.ScalarKind.FLOAT64, + ] def is_integer(symbol_type: ts.TypeSpec) -> bool: @@ -295,7 +298,10 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: - return extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL + return ( + isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + and dtype.kind is ts.ScalarKind.BOOL + ) def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: @@ -385,11 +391,10 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - match symbol_type: - case ts.ScalarType(): - return [] - case ts.FieldType(dims): - return dims + if isinstance(symbol_type, ts.ScalarType): + return [] + if isinstance(symbol_type, ts.FieldType): + return symbol_type.dims raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") @@ -502,7 +507,9 @@ def promote( return types[0] elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) - dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) + extracted_dtypes = [extract_dtype(type_) for type_ in types] + assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) + dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index fa8c9b9ab1..060d56aea2 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,21 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass from typing import Iterator, Optional, Sequence, Union -from gt4py.eve.type_definitions import IntEnum -from gt4py.eve.utils import content_hash -from gt4py.next import common as func_common +from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types +from gt4py.next import common -@dataclass(frozen=True) -class TypeSpec: - def __hash__(self) -> int: - return hash(content_hash(self)) - - def __init_subclass__(cls) -> None: - cls.__hash__ = TypeSpec.__hash__ # type: ignore[method-assign] +class TypeSpec(eve_datamodels.DataModel, kw_only=False, frozen=True): ... # type: ignore[call-arg] class DataType(TypeSpec): @@ -40,14 +32,12 @@ class CallableType: """ -@dataclass(frozen=True) class DeferredType(TypeSpec): """Dummy used to represent a type not yet inferred.""" constraint: Optional[type[TypeSpec] | tuple[type[TypeSpec], ...]] -@dataclass(frozen=True) class VoidType(TypeSpec): """ Return type of a function without return values. @@ -56,22 +46,20 @@ class VoidType(TypeSpec): """ -@dataclass(frozen=True) class DimensionType(TypeSpec): - dim: func_common.Dimension + dim: common.Dimension -@dataclass(frozen=True) class OffsetType(TypeSpec): # TODO(havogt): replace by ConnectivityType - source: func_common.Dimension - target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] def __str__(self) -> str: return f"Offset[{self.source}, {self.target}]" -class ScalarKind(IntEnum): +class ScalarKind(eve_types.IntEnum): BOOL = 1 INT32 = 32 INT64 = 64 @@ -80,7 +68,6 @@ class ScalarKind(IntEnum): STRING = 3001 -@dataclass(frozen=True) class ScalarType(DataType): kind: ScalarKind shape: Optional[list[int]] = None @@ -92,31 +79,43 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" -@dataclass(frozen=True) -class TupleType(DataType): - types: list[DataType] - - def __str__(self) -> str: - return f"tuple[{', '.join(map(str, self.types))}]" +class ListType(DataType): + """Represents a neighbor list in the ITIR representation. - def __iter__(self) -> Iterator[DataType]: - yield from self.types + Note: not used in the frontend. + """ - def __len__(self) -> int: - return len(self.types) + element_type: DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None -@dataclass(frozen=True) class FieldType(DataType, CallableType): - dims: list[func_common.Dimension] - dtype: ScalarType + dims: list[common.Dimension] + dtype: ScalarType | ListType def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" -@dataclass(frozen=True) +class TupleType(DataType): + # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously + # introduced before we checked the annotations at runtime. All attributes of + # a type that are types themselves must be concrete. + types: list[DataType | DimensionType | DeferredType] + + def __str__(self) -> str: + return f"tuple[{', '.join(map(str, self.types))}]" + + def __iter__(self) -> Iterator[DataType | DimensionType | DeferredType]: + yield from self.types + + def __len__(self) -> int: + return len(self.types) + + class FunctionType(TypeSpec, CallableType): pos_only_args: Sequence[TypeSpec] pos_or_kw_args: dict[str, TypeSpec] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..e601556e55 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -10,7 +10,6 @@ import builtins import collections.abc -import dataclasses import functools import types import typing @@ -105,7 +104,7 @@ def from_type_hint( raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = [recursive_make_symbol(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) - return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=tuple_types) case common.Field: if (n_args := len(args)) != 2: @@ -168,7 +167,6 @@ def from_type_hint( raise ValueError(f"'{type_hint}' type is not supported.") -@dataclasses.dataclass(frozen=True) class UnknownPythonObject(ts.TypeSpec): _object: Any @@ -217,9 +215,9 @@ def from_value(value: Any) -> ts.TypeSpec: # not needed anymore. elems = [from_value(el) for el in value] assert all(isinstance(elem, ts.DataType) for elem in elems) - return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=elems) elif isinstance(value, types.ModuleType): - return UnknownPythonObject(_object=value) + return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 05be5f3db0..75b07fd8a0 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -10,9 +10,9 @@ import enum import numbers +import sys import types import typing -from typing import Set # noqa: F401 [unused-import] used in exec() context from typing import ( Any, Callable, @@ -26,6 +26,7 @@ MutableSequence, Optional, Sequence, + Set, # noqa: F401 [unused-import] used in exec() context Tuple, Type, TypeVar, @@ -555,6 +556,18 @@ class WrongModel: ("typing.MutableSequence[int]", ([1, 2, 3], []), ((1, 2, 3), tuple(), 1, [1.0], {1})), ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), + pytest.param( + "int | float | str", + [1, 3.0, "one"], + [[1], [], 1j], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), + pytest.param( + "typing.List[int|float]", + [[1, 2.0], []], + [1, 2.0, [1, "2.0"]], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7eb4e86adb..b6b70af07c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -46,8 +46,8 @@ bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -float64_list_type = it_ts.ListType(element_type=float64_type) -int_list_type = it_ts.ListType(element_type=int_type) +float64_list_type = ts.ListType(element_type=float64_type) +int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) @@ -77,8 +77,8 @@ def expression_test_cases(): (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), - (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), + (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), + (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), it_ts.NamedRangeType(dim=Vertex), @@ -110,7 +110,7 @@ def expression_test_cases(): # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), - it_ts.ListType(element_type=it_on_e_of_e_type.element_type), + ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast (im.call("cast_")(1, "int32"), int_type), From 9a56fbd710c72bdceb9bbba26a6a6622b6e42e54 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Mon, 13 Jan 2025 18:56:24 +0100 Subject: [PATCH 086/178] feat[next]: SDFGConvertible Program for dace_fieldview backend (#1742) Add a `decorator.Program` subclass, which implements `SDFGConvertible` to `dace_fieldview` backend, analogous to the one in `dace_iterator`. --------- Co-authored-by: Edoardo Paone --- src/gt4py/next/ffront/decorator.py | 6 +- .../next/iterator/transforms/extractors.py | 72 +++++ .../runners/dace_fieldview/program.py | 248 ++++++++++++++++++ .../runners/dace_fieldview/workflow.py | 4 +- .../feature_tests/dace/test_orchestration.py | 93 +++---- .../feature_tests/dace/test_program.py | 134 ++++++++++ .../iterator_tests/test_extractors.py | 102 +++++++ 7 files changed, 606 insertions(+), 53 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/extractors.py create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/program.py create mode 100644 tests/next_tests/integration_tests/feature_tests/dace/test_program.py create mode 100644 tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index d187095019..d1631a461d 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -80,7 +80,9 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[common.OffsetProviderType] = None + connectivities: Optional[common.OffsetProvider] = ( + None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information + ) @classmethod def from_function( @@ -304,7 +306,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: try: - from gt4py.next.program_processors.runners.dace_iterator import Program + from gt4py.next.program_processors.runners.dace_fieldview.program import Program except ImportError: pass diff --git a/src/gt4py/next/iterator/transforms/extractors.py b/src/gt4py/next/iterator/transforms/extractors.py new file mode 100644 index 0000000000..04c2b09139 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/extractors.py @@ -0,0 +1,72 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts + + +class SymbolNameSetExtractor(eve.NodeVisitor): + """Extract a set of symbol names""" + + def visit_Literal(self, node: itir.Literal) -> set[str]: + return set() + + def generic_visitor(self, node: itir.Node) -> set[str]: + input_fields: set[str] = set() + for child in eve.trees.iter_children_values(node): + input_fields |= self.visit(child) + return input_fields + + def visit_Node(self, node: itir.Node) -> set[str]: + return set() + + def visit_Program(self, node: itir.Program) -> set[str]: + names = set() + for stmt in node.body: + names |= self.visit(stmt) + return names + + def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: + names = set() + for stmt in node.true_branch + node.false_branch: + names |= self.visit(stmt) + return names + + def visit_Temporary(self, node: itir.Temporary) -> set[str]: + return set() + + def visit_SymRef(self, node: itir.SymRef) -> set[str]: + return {str(node.id)} + + @classmethod + def only_fields(cls, program: itir.Program) -> set[str]: + field_param_names = [ + str(param.id) for param in program.params if isinstance(param.type, ts.FieldType) + ] + return {name for name in cls().visit(program) if name in field_param_names} + + +class InputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names passed into field operators within a program.""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.expr) + + def visit_FunCall(self, node: itir.FunCall) -> set[str]: + input_fields = set() + for arg in node.args: + input_fields |= self.visit(arg) + return input_fields + + +class OutputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names written to within a program""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.target) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py new file mode 100644 index 0000000000..7f809152c5 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py @@ -0,0 +1,248 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import collections +import dataclasses +import itertools +import typing +from typing import Any, ClassVar, Optional, Sequence + +import dace +import numpy as np + +from gt4py.next import backend as next_backend, common +from gt4py.next.ffront import decorator +from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.iterator.transforms import extractors as extractors +from gt4py.next.otf import arguments, recipes, toolchain +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass(frozen=True) +class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): + """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR.""" + + sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict) + # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, + # there is no name mangling of the connectivity tables used across the nested SDFGs + # since they share the same memory address. + connectivity_tables_data_descriptors: ClassVar[ + dict[str, dace.data.Array] + ] = {} # symbolically defined + + def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: + if (self.backend is None) or "dace" not in self.backend.name.lower(): + raise ValueError("The SDFG can be generated only for the DaCe backend.") + + offset_provider: common.OffsetProvider = { + **(self.connectivities or {}), + **self._implicit_offset_provider, + } + column_axis = kwargs.get("column_axis", None) + + # TODO(ricoh): connectivity tables required here for now. + gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir( + toolchain.CompilableProgram( + data=self.past_stage, + args=arguments.CompileTimeArgs( + args=tuple(p.type for p in self.past_stage.past_node.params), + kwargs={}, + column_axis=column_axis, + offset_provider=offset_provider, + ), + ) + ) + program = gtir_stage.data + program = itir_transforms.apply_fieldview_transforms( # run the transforms separately because they require the runtime info + program, offset_provider=offset_provider + ) + object.__setattr__( + gtir_stage, + "data", + program, + ) + object.__setattr__( + gtir_stage.args, "offset_provider", gtir_stage.args.offset_provider_type + ) # TODO(ricoh): currently this is circumventing the frozenness of CompileTimeArgs + # in order to isolate DaCe from the runtime tables in connectivities.offset_provider. + # These are needed at the time of writing for mandatory GTIR passes. + # Remove this as soon as Program does not expect connectivity tables anymore. + + _crosscheck_dace_parsing( + dace_parsed_args=[*args, *kwargs.values()], + gt4py_program_args=[p.type for p in program.params], + ) + + compile_workflow = typing.cast( + recipes.OTFCompileWorkflow, + self.backend.executor + if not hasattr(self.backend.executor, "step") + else self.backend.executor.step, + ) # We know which backend we are using, but we don't know if the compile workflow is cached. + # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously + # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with + # the other parts of the workaround when possible. + sdfg = dace.SDFG.from_json( + compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code + ) + + self.sdfg_closure_cache["arrays"] = sdfg.arrays + + # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, + # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG + field_params = { + str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType) + } + + def single_horizontal_dim_per_field( + fields: typing.Iterable[itir.Sym], + ) -> typing.Iterator[tuple[str, common.Dimension]]: + for field in fields: + assert isinstance(field.type, ts.FieldType) + horizontal_dims = [ + dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL + ] + # do nothing for fields with multiple horizontal dimensions + # or without horizontal dimensions + # this is only meant for use with unstructured grids + if len(horizontal_dims) == 1: + yield str(field.id), horizontal_dims[0] + + input_fields = ( + field_params[name] for name in extractors.InputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields)) + + output_fields = ( + field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields)) + + # TODO (ricoh): bring back sdfg.offset_providers_per_input_field. + # A starting point would be to use the "trace_shifts" pass on GTIR + # and associate the extracted shifts with each input field. + # Analogous to the version in `runners.dace_iterator.__init__`, which + # was removed when merging #1742. + + return sdfg + + def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: + """ + Return the closure arrays of the SDFG represented by this object + as a mapping between array name and the corresponding value. + + The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. + The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that + the offset providers are not part of GT4Py Program's arguments. + Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. + """ + closure_dict: dict[str, Any] = {} + + if self.connectivities: + symbols = {} + with_table = [ + name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn) + ] + in_arrays_with_id = [ + (name, conn_id) + for name in with_table + if (conn_id := dace_utils.connectivity_identifier(name)) + in self.sdfg_closure_cache["arrays"] + ] + in_arrays = (name for name, _ in in_arrays_with_id) + name_axis = list(itertools.product(in_arrays, [0, 1])) + + def size_symbol_name(name: str, axis: int) -> str: + return dace_utils.field_size_symbol_name( + dace_utils.connectivity_identifier(name), axis + ) + + connectivity_tables_size_symbols = { + (sname := size_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + def stride_symbol_name(name: str, axis: int) -> str: + return dace_utils.field_stride_symbol_name( + dace_utils.connectivity_identifier(name), axis + ) + + connectivity_table_stride_symbols = { + (sname := stride_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols + + # Define the storage location (e.g. CPU, GPU) of the connectivity tables + if "storage" not in self.connectivity_tables_data_descriptors: + for _, conn_id in in_arrays_with_id: + self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[ + "arrays" + ][conn_id].storage + break + + # Build the closure dictionary + for name, conn_id in in_arrays_with_id: + if conn_id not in self.connectivity_tables_data_descriptors: + conn = self.connectivities[name] + assert common.is_neighbor_table(conn) + self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( + dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type), + shape=[ + symbols[dace_utils.field_size_symbol_name(conn_id, 0)], + symbols[dace_utils.field_size_symbol_name(conn_id, 1)], + ], + strides=[ + symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], + symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], + ], + storage=Program.connectivity_tables_data_descriptors["storage"], + ) + closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id] + + return closure_dict + + def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: + return [p.id for p in self.past_stage.past_node.params], [] + + +def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None: + for dace_parsed_arg, gt4py_program_arg in zip( + dace_parsed_args, + gt4py_program_args, + strict=False, # dace does not see implicit size args + ): + match dace_parsed_arg: + case dace.data.Scalar(): + assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) + case bool() | np.bool_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.BOOL + case int() | np.integer(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] + case float() | np.floating(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + case str() | np.str_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.STRING + case dace.data.Array(): + assert isinstance(gt4py_program_arg, ts.FieldType) + assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) + assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) + assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) + case dace.data.Structure() | dict() | collections.OrderedDict(): + # offset provider + pass + case _: + raise ValueError( + f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}" + ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index a38a50d886..779dc8a1c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -37,6 +37,7 @@ class DaCeTranslator( ): device_type: core_defs.DeviceType auto_optimize: bool + itir_transforms_off: bool = False def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -51,7 +52,8 @@ def generate_sdfg( auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: - ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + if not self.itir_transforms_off: + ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( ir, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 08904c06f3..cd71c306eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -12,15 +12,16 @@ import gt4py.next as gtx from gt4py.next import allocators as gtx_allocators, common as gtx_common +from gt4py._core import definitions as core_defs from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case, unstructured_case +from next_tests.integration_tests.cases import cartesian_case, unstructured_case # noqa: F401 from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( E2V, E2VDim, Edge, Vertex, - exec_alloc_descriptor, - mesh_descriptor, + exec_alloc_descriptor, # noqa: F401 + mesh_descriptor, # noqa: F401 ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, @@ -37,23 +38,17 @@ pytestmark = pytest.mark.requires_dace -def test_sdfgConvertible_laplap(cartesian_case): +def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(ricoh): enable test after adding GTIR support - pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") - - allocator, backend = unstructured_case.allocator, unstructured_case.backend - - if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): - import cupy as xp - else: - import numpy as xp + backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() + xp = in_field.array_ns + # Test DaCe closure support @dace.program def sdfg(): @@ -88,16 +83,13 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift -def test_sdfgConvertible_connectivities(unstructured_case): +def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(ricoh): enable test after adding GTIR support - pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") - allocator, backend = unstructured_case.allocator, unstructured_case.backend - if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): + if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global @@ -113,6 +105,15 @@ def test_sdfgConvertible_connectivities(unstructured_case): name="OffsetProvider", ) + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, + ) + + testee2 = testee.with_backend(backend).with_connectivities({"E2V": e2v}) + @dace.program def sdfg( a: dace.data.Array(dtype=dace.float64, shape=(rows,), storage=dace_storage_type), @@ -120,17 +121,10 @@ def sdfg( offset_provider: OffsetProvider_t, connectivities: dace.compiletime, ): - testee.with_backend(backend).with_connectivities(connectivities)( - a, out, offset_provider=offset_provider - ) + testee2.with_connectivities(connectivities)(a, out, offset_provider=offset_provider) + return out - e2v = gtx.as_connectivity( - [Edge, E2VDim], - codomain=Vertex, - data=xp.asarray([[0, 1], [1, 2], [2, 0]]), - allocator=allocator, - ) - connectivities = {"E2V": e2v.__gt_type__()} + connectivities = {"E2V": e2v} # replace 'e2v' with 'e2v.__gt_type__()' when GTIR is AOT offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -138,23 +132,21 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) - e2v_ndarray_copy = ( - e2v.ndarray.copy() - ) # otherwise DaCe complains about the gt4py custom allocated view - # This is a low level interface to call the compiled SDFG. - # It is not supposed to be used in user code. - # The high level interface should be provided by a DaCe Orchestrator, - # i.e. decorator that hides the low level operations. - # This test checks only that the SDFGConvertible interface works correctly. + + def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> int: + # NumPy strides: number of bytes to jump + # DaCe strides: number of elements to jump + return arg.strides[axis] // arg.itemsize + cSDFG( a, out, offset_provider, rows=3, cols=2, - connectivity_E2V=e2v_ndarray_copy, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), ) e2v_np = e2v.asnumpy() @@ -166,18 +158,19 @@ def sdfg( data=xp.asarray([[1, 0], [2, 1], [0, 2]]), allocator=allocator, ) - e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) - cSDFG( - a, - out, - offset_provider, - rows=3, - cols=2, - connectivity_E2V=e2v_ndarray_copy, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), - ) + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) e2v_np = e2v.asnumpy() assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py new file mode 100644 index 0000000000..db0f90b409 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -0,0 +1,134 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import common + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + Cell, + Edge, + IDim, + JDim, + KDim, + Vertex, + mesh_descriptor, # noqa: F401 +) + + +try: + import dace + + from gt4py.next.program_processors.runners import dace as dace_backends +except ImportError: + from types import ModuleType + from typing import Optional + + from gt4py.next import backend as next_backend + + dace: Optional[ModuleType] = None + dace_backends: Optional[ModuleType] = None + + +@pytest.fixture( + params=[ + pytest.param(dace_backends.run_dace_cpu, marks=pytest.mark.requires_dace), + pytest.param( + dace_backends.run_dace_gpu, marks=(pytest.mark.requires_gpu, pytest.mark.requires_dace) + ), + ] +) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.fixture +def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider=mesh_descriptor.offset_provider, + default_sizes={ + Vertex: mesh_descriptor.num_vertices, + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + KDim: 10, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_halo_exchange_helper_attrs(unstructured): + local_int = gtx.int + + @gtx.field_operator(backend=unstructured.backend) + def testee_op( + a: gtx.Field[[Vertex, KDim], gtx.int], + ) -> gtx.Field[[Vertex, KDim], gtx.int]: + return a + local_int(10) + + @gtx.program(backend=unstructured.backend) + def testee_prog( + a: gtx.Field[[Vertex, KDim], gtx.int], + b: gtx.Field[[Vertex, KDim], gtx.int], + c: gtx.Field[[Vertex, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + dace_storage_type = ( + dace.StorageType.GPU_Global + if unstructured.backend == dace_backends.run_dace_gpu + else dace.StorageType.Default + ) + + rows = dace.symbol("rows") + cols = dace.symbol("cols") + + @dace.program + def testee_dace( + a: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + b: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + c: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + ): + testee_prog(a, b, c) + + # if simplify=True, DaCe might inline the nested SDFG coming from Program.__sdfg__, + # effectively erasing the attributes we want to test for here + sdfg = testee_dace.to_sdfg(simplify=False) + + testee = next( + subgraph for subgraph in sdfg.all_sdfgs_recursive() if subgraph.name == "testee_prog" + ) + + assert testee.gt4py_program_input_fields == {"a": Vertex, "b": Vertex} + assert testee.gt4py_program_output_fields == {"b": Vertex, "c": Vertex} diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py new file mode 100644 index 0000000000..7358ab3d8f --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py @@ -0,0 +1,102 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import typing + +import pytest + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import extractors + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + IDim, + JDim, + KDim, +) + + +if typing.TYPE_CHECKING: + from types import ModuleType + from typing import Optional + +try: + import dace + + from gt4py.next.program_processors.runners.dace import run_dace_cpu +except ImportError: + from gt4py.next import backend as next_backend + + dace: Optional[ModuleType] = None + run_dace_cpu: Optional[next_backend.Backend] = None + + +@pytest.fixture(params=[pytest.param(run_dace_cpu, marks=pytest.mark.requires_dace), gtx.gtfn_cpu]) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_input_names_extractor_cartesian(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + input_field_names = extractors.InputNamesExtractor.only_fields(testee.gtir) + assert input_field_names == {"a", "b"} + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_output_names_extractor(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(a, out=b) + testee_op(a, out=c) + + output_field_names = extractors.OutputNamesExtractor.only_fields(testee.gtir) + assert output_field_names == {"b", "c"} From 4dc153137d51d7305d4c743415b7076032b36948 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 14 Jan 2025 09:53:41 +0100 Subject: [PATCH 087/178] ci: Re-enable CI on GH200 (#1653) Co-authored-by: edopao --- ci/cscs-ci.yml | 4 +--- .../multi_feature_tests/test_code_generation.py | 5 +++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index c2a872c1c4..349089ebfa 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -50,8 +50,6 @@ stages: CUPY_PACKAGE: cupy-cuda12x CUPY_VERSION: 13.3.0 UBUNTU_VERSION: 22.04 - # TODO: enable CI job when Todi is back in operational state - when: manual build_py311_baseimage_x86_64: extends: .build_baseimage_x86_64 @@ -133,7 +131,7 @@ build_py310_image_aarch64: VARIANT: [-nomesh, -atlas] SUBVARIANT: [-cuda11x, -cpu] .test_helper_aarch64: - extends: [.container-runner-todi-gh200, .test_helper] + extends: [.container-runner-daint-gh200, .test_helper] parallel: matrix: - SUBPACKAGE: [cartesian, storage] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 5a43144b4b..57c52eae12 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -583,6 +583,11 @@ def test_K_offset_write(backend): if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") + if backend == "dace:gpu": + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1684" + ) + arraylib = get_array_library(backend) array_shape = (1, 1, 4) K_values = arraylib.arange(start=40, stop=44) From 8346bcd226c0588cd21baf4102d08c96993c763e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 14 Jan 2025 11:31:30 +0100 Subject: [PATCH 088/178] bug[next]: Fix CSE inside stencil (#1793) The common subexpression elimination did not work at all inside of stencils due to a typo. Fixed & added a test to cover this. --- src/gt4py/next/iterator/transforms/cse.py | 12 ++++++---- .../transforms_tests/test_cse.py | 24 ++++++++++--------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 4f3fcbfdd5..ccaaf563f5 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -86,7 +86,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]: + if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]: return False return True elif isinstance(node, itir.Lambda): @@ -429,9 +429,9 @@ def apply( return cls().visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): - if cpm.is_call_to("as_fieldop", node): + if cpm.is_call_to(node, "as_fieldop"): assert not kwargs.get("within_stencil") - within_stencil = cpm.is_call_to("as_fieldop", node) or kwargs.get("within_stencil") + within_stencil = cpm.is_call_to(node, "as_fieldop") or kwargs.get("within_stencil") return super().generic_visit(node, **(kwargs | {"within_stencil": within_stencil})) @@ -443,10 +443,14 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): def predicate(subexpr: itir.Expr, num_occurences: int): # note: be careful here with the syntatic context: the expression might be in local - # view, even though the syntactic context `node` is in field view. + # view, even though the syntactic context of `node` is in field view. # note: what is extracted is sketched in the docstring above. keep it updated. if num_occurences > 1: if within_stencil: + # TODO(tehrengruber): Lists must not be extracted to avoid errors in partial + # shift detection of UnrollReduce pass. Solve there. See #1795. + if isinstance(subexpr.type, ts.ListType): + return False return True # condition is only necessary since typing on lambdas is not preserved during # the transformation diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index f4ea2d7fe1..14860d9bdd 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -26,17 +26,9 @@ def offset_provider_type(request): def test_trivial(): - common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) - testee = ir.FunCall(fun=ir.SymRef(id="plus"), args=[common, common]) - expected = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="_cs_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="_cs_1"), ir.SymRef(id="_cs_1")] - ), - ), - args=[common], - ) + common = im.plus("x", "y") + testee = im.plus(common, common) + expected = im.let("_cs_1", common)(im.plus("_cs_1", "_cs_1")) actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -291,3 +283,13 @@ def test_field_extraction_outside_asfieldop(): actual = CSE.apply(testee, within_stencil=False) assert actual == expected + + +def test_scalar_extraction_inside_as_fieldop(): + common = im.plus(1, 2) + + testee = im.as_fieldop(im.lambda_()(im.plus(common, common)))() + expected = im.as_fieldop(im.lambda_()(im.let("_cs_1", common)(im.plus("_cs_1", "_cs_1"))))() + + actual = CSE.apply(testee, within_stencil=False) + assert actual == expected From db5325bf060a5b783b185241a39d5fa27331cf7e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 14 Jan 2025 12:36:38 +0100 Subject: [PATCH 089/178] fix[next]: gtfn with offset name != local dimension name (#1789) --- src/gt4py/next/otf/binding/nanobind.py | 1 + .../codegens/gtfn/gtfn_module.py | 5 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 4 ++ .../runners/dace_fieldview/gtir_sdfg.py | 1 + .../next/type_system/type_specifications.py | 2 +- tests/next_tests/integration_tests/cases.py | 2 +- .../test_offset_dimensions_names.py | 63 +++++++++++++++++++ 7 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index edd56fad48..3abf49788f 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) + # cannot be ListType: the concept is represented as Field with local Dimension in this interface assert isinstance(type_.dtype, ts.ScalarType) dtype = cpp_interface.render_scalar_type(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 020b1f55ea..48f15acffb 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -135,8 +135,9 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity_type.source_dim.value}_t, " - f"generated::{name}_t, {connectivity_type.max_neighbors}" + f"generated::{connectivity_type.domain[0].value}_t, " + f"generated::{connectivity_type.domain[1].value}_t, " + f"{connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index f7bb1805e0..3dc7998a54 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -209,6 +209,10 @@ def _collect_offset_definitions( ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) + if offset_name != connectivity_type.neighbor_dim.value: + offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition( + name=Sym(id=connectivity_type.neighbor_dim.value) + ) for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 10895ce66e..baddb7b699 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -290,6 +290,7 @@ def _add_storage( # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions + # ListType not supported: concept is represented as Field with local Dimension assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) if tuple_name is None: diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 060d56aea2..c1c0f0b5e1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -82,7 +82,7 @@ def __str__(self) -> str: class ListType(DataType): """Represents a neighbor list in the ITIR representation. - Note: not used in the frontend. + Note: not used in the frontend. The concept is represented as Field with local Dimension. """ element_type: DataType diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 759cd1cf1f..8a78307f87 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,6 +28,7 @@ common, constructors, field_utils, + utils as gt_utils, ) from gt4py.next.ffront import decorator from gt4py.next.type_system import type_specifications as ts, type_translation @@ -55,7 +56,6 @@ mesh_descriptor, ) -from gt4py.next import utils as gt_utils # mypy does not accept [IDim, ...] as a type diff --git a/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py new file mode 100644 index 0000000000..f95ed4c3a7 --- /dev/null +++ b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py @@ -0,0 +1,63 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import Dims, Field, common + +from next_tests import definitions as test_defs +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests import ffront_test_utils + + +V = gtx.Dimension("V") +E = gtx.Dimension("E") +Neigh = gtx.Dimension("Neigh", kind=common.DimensionKind.LOCAL) +Off = gtx.FieldOffset("Off", source=E, target=(V, Neigh)) + + +@pytest.fixture +def case(): + mesh = ffront_test_utils.simple_mesh() + exec_alloc_descriptor = test_defs.ProgramBackendId.GTFN_CPU.load() + v2e_arr = mesh.offset_provider["V2E"].ndarray + return cases.Case( + exec_alloc_descriptor, + offset_provider={ + "Off": common._connectivity( + v2e_arr, + codomain=E, + domain={V: v2e_arr.shape[0], Neigh: 4}, + skip_value=None, + ), + }, + default_sizes={ + V: mesh.num_vertices, + E: mesh.num_edges, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, + ) + + +def test_offset_dimension_name_differ(case): + """ + Ensure that gtfn works with offset name that differs from the name of the local dimension. + + If the value of the `NeighborConnectivityType.neighbor_dim` did not match the `FieldOffset` value, + gtfn would silently ignore the neighbor index, see https://github.com/GridTools/gridtools/pull/1814. + """ + + @gtx.field_operator + def foo(a: Field[Dims[E], float]) -> Field[Dims[V], float]: + return a(Off[1]) + + cases.verify_with_default_data( + case, foo, lambda a: a[case.offset_provider["Off"].ndarray[:, 1]] + ) From 21e7f6463b362717a360fabb23557e0922f9be29 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 15 Jan 2025 03:01:42 +0100 Subject: [PATCH 090/178] feature[next]: Non-tree-size-increasing collapse tuple on ifs (#1762) Removing the tuple expressions across `if_` calls on ITIR has been a pain point in the past. While the `PROPAGATE_TO_IF_ON_TUPLES` option of the `CollapseTuplePass` works very reliably, the resulting increase in the tree size has been prohibitive. With the refactoring to GTIR this problem became much less pronounced, as we could restrict the propagation to field-level, i.e., outside of stencils, but the tree still grew exponentially in the number of references to boolean arguments used inside `if_` conditions. This PR adds an additional option `PROPAGATE_TO_IF_ON_TUPLES_CPS` to the `CollapseTuplePass`, which is similar to the existing `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. into the tree. This allows removal of tuple expressions across `if_` calls without increasing the size of the tree. This is particularly important for `if` statements in the frontend, where outwards propagation can have devastating effects on the tree size, without any gained optimization potential. For example ``` complex_lambda(if cond1 if cond2 {...} else: {...} else {...}) ``` is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate, and hence duplicate, `complex_lambda` three times, while we only want to get rid of the tuple expressions inside of the `if_`s. Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. --- .../iterator/transforms/collapse_tuple.py | 191 ++++++++++++++++-- .../next/iterator/transforms/pass_manager.py | 25 ++- .../next/iterator/type_system/inference.py | 76 +++++-- .../iterator/type_system/type_synthesizer.py | 4 + .../iterator_tests/test_type_inference.py | 54 ++++- .../transforms_tests/test_collapse_tuple.py | 76 ++++++- 6 files changed, 376 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index b64886f729..0a0cf6d37e 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -28,10 +28,11 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( - fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + fun=node.fun, + args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) @@ -47,6 +48,39 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True +def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: + """ + Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. + + Let forms with trivial body and args as well as `if` calls with trivial branches are also + considered trivial. + + >>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) + True + >>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a")) + True + >>> _is_trivial_or_tuple_thereof_expr( + ... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t")) + ... ) + True + """ + if isinstance(node, (ir.SymRef, ir.Literal)): + return True + if cpm.is_call_to(node, "make_tuple"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) + if cpm.is_call_to(node, "tuple_get"): + return _is_trivial_or_tuple_thereof_expr(node.args[1]) + # This will duplicate the condition and increase the size of the tree, but this is probably + # acceptable. + if cpm.is_call_to(node, "if_"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) + if cpm.is_let(node): + return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let + _is_trivial_or_tuple_thereof_expr(arg) for arg in node.args + ) + return False + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still @@ -76,28 +110,42 @@ class Flag(enum.Flag): #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. + #: into the tree, allowing removal of tuple expressions across `if_` calls without + #: increasing the size of the tree. This is particularly important for `if` statements + #: in the frontend, where outwards propagation can have devastating effects on the tree + #: size, without any gained optimization potential. For example + #: ``` + #: complex_lambda(if cond1 + #: if cond2 + #: {...} + #: else: + #: {...} + #: else + #: {...}) + #: ``` + #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate + #: `complex_lambda` three times, while we only want to get rid of the tuple expressions + #: inside of the `if_`s. + #: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` PROPAGATE_NESTED_LET = enum.auto() - #: `let(a, 1)(a)` -> `1` + #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() @classmethod def all(self) -> CollapseTuple.Flag: return functools.reduce(operator.or_, self.__members__.values()) + uids: eve_utils.UIDGenerator ignore_tuple_size: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) - # we use one UID generator per instance such that the generated ids are - # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") - ) - @classmethod def apply( cls, @@ -111,6 +159,7 @@ def apply( flags: Optional[Flag] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -127,6 +176,7 @@ def apply( """ flags = flags or cls.flags offset_provider_type = offset_provider_type or {} + uids = uids or eve_utils.UIDGenerator() if isinstance(node, ir.Program): within_stencil = False @@ -145,6 +195,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, flags=flags, + uids=uids, ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important @@ -185,6 +236,10 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node, **kwargs) if result is not None: + assert ( + result is not node + ) # transformation should have returned None, since nothing changed + itir_type_inference.reinfer(result) return result return None @@ -263,13 +318,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op if node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[str, ir.Expr] = {} + bound_vars: dict[ir.Sym, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): - el_name = self._letify_make_tuple_uids.sequential_id() - new_args.append(im.ref(el_name)) - bound_vars[el_name] = arg + el_name = self.uids.sequential_id(prefix="__ct_el") + new_args.append(im.ref(el_name, arg.type)) + bound_vars[im.sym(el_name, arg.type)] = arg else: new_args.append(arg) @@ -312,6 +367,102 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return im.if_(cond, new_true_branch, new_false_branch) return None + def transform_propagate_to_if_on_tuples_cps( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + # The basic idea of this transformation is to remove tuples across if-stmts by rewriting + # the expression in continuation passing style, e.g. something like a tuple reordering + # ``` + # let t = if True then {1, 2} else {3, 4} in + # {t[1], t[0]}) + # end + # ``` + # is rewritten into: + # ``` + # let cont = λ(el0, el1) → {el1, el0} in + # if True then cont(1, 2) else cont(3, 4) + # end + # ``` + # Note how the `make_tuple` call argument of the `if` disappears. Since lambda functions + # are currently inlined (due to limitations of the domain inference) we will only + # gain something compared `PROPAGATE_TO_IF_ON_TUPLES` if the continuation `cont` is trivial, + # e.g. a `make_tuple` call like in the example. In that case we can inline the trivial + # continuation and end up with an only moderately larger tree, e.g. + # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this + # tuple reordering example here. + + if cpm.is_call_to(node, "if_"): + return None + + # The first argument that is eligible also transforms all remaining args (They will be + # part of the continuation which is recursively transformed). + for i, arg in enumerate(node.args): + if cpm.is_call_to(arg, "if_"): + itir_type_inference.reinfer(arg) + + cond, true_branch, false_branch = arg.args # e.g. `True`, `{1, 2}`, `{3, 4}` + if not any( + isinstance(branch.type, ts.TupleType) for branch in [true_branch, false_branch] + ): + continue + tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above + tuple_len = len(tuple_type.types) + + # build and simplify continuation, e.g. λ(el0, el1) → {el1, el0} + itir_type_inference.reinfer(node) + assert node.type + f_type = ts.FunctionType( # type of continuation in order to keep full type info + pos_only_args=tuple_type.types, + pos_or_kw_args={}, + kw_only_args={}, + returns=node.type, + ) + f_params = [ + im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) + for type_ in tuple_type.types + ] + f_args = [im.ref(param.id, param.type) for param in f_params] + f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) + # simplify, e.g., inline trivial make_tuple args + new_f_body = self.fp_transform(f_body, **kwargs) + # if the continuation did not simplify there is nothing to gain. Skip + # transformation of this argument. + if new_f_body is f_body: + continue + # if the function is not trivial the transformation we would create a larger tree + # after inlining so we skip transformation this argument. + if not _is_trivial_or_tuple_thereof_expr(new_f_body): + continue + f = im.lambda_(*f_params)(new_f_body) + + # this is the symbol refering to the tuple value inside the two branches of the + # if, e.g. a symbol refering to `{1, 2}` and `{3, 4}` respectively + tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + # this is the symbol refering to our continuation, e.g. `cont` in our example. + f_var = self.uids.sequential_id(prefix="__ct_cont") + new_branches = [] + for branch in [true_branch, false_branch]: + new_branch = im.let(tuple_var, branch)( + im.call(im.ref(f_var, f_type))( # call to the continuation + *( + im.tuple_get(i, im.ref(tuple_var, branch.type)) + for i in range(tuple_len) + ) + ) + ) + new_branches.append(self.fp_transform(new_branch, **kwargs)) + + # assemble everything together + new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) + new_node = inline_lambda(new_node, eligible_params=[True]) + assert cpm.is_call_to(new_node, "if_") + new_node = im.if_( + cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) + ) + return new_node + + return None + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` @@ -339,9 +490,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional return None def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let - # `let(a, 1)(a)` -> `1` - for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let - return arg + if cpm.is_let(node): + if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + return arg + if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index d967c8fbb8..6906f81e3f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -62,6 +62,7 @@ def apply_common_transforms( tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() + collapse_tuple_uids = eve_utils.UIDGenerator() ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) @@ -73,7 +74,12 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet @@ -90,7 +96,12 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = CollapseTuple.apply( + inlined, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline @@ -122,7 +133,11 @@ def apply_common_transforms( # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: ir = CollapseTuple.apply( - ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ir, + ignore_tuple_size=True, + uids=collapse_tuple_uids, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -160,7 +175,9 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) # type: ignore[assignment] # type is still `itir.Program` ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1da59546c0..d0d39cbd34 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -292,7 +292,9 @@ def type_synthesizer(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns - return type_synthesizer + return ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer, store_inferred_type_in_node=False + ) class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @@ -312,6 +314,15 @@ def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir T = TypeVar("T", bound=itir.Node) +_INITIAL_CONTEXT = { + name: ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], + # builtin functions are polymorphic + store_inferred_type_in_node=False, + ) + for name in type_synthesizer.builtin_type_synthesizers.keys() +} + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): @@ -323,11 +334,13 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider_type: common.OffsetProviderType + offset_provider_type: Optional[common.OffsetProviderType] #: Mapping from a dimension name to the actual dimension instance. - dimensions: dict[str, common.Dimension] + dimensions: Optional[dict[str, common.Dimension]] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. allow_undeclared_symbols: bool + #: Reinference-mode skipping already typed nodes. + reinfer: bool @classmethod def apply( @@ -420,24 +433,45 @@ def apply( ) ), allow_undeclared_symbols=allow_undeclared_symbols, + reinfer=False, ) if not inplace: node = copy.deepcopy(node) - instance.visit( - node, - ctx={ - name: ObservableTypeSynthesizer( - type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], - # builtin functions are polymorphic - store_inferred_type_in_node=False, - ) - for name in type_synthesizer.builtin_type_synthesizers.keys() - }, + instance.visit(node, ctx=_INITIAL_CONTEXT) + return node + + @classmethod + def apply_reinfer(cls, node: T) -> T: + """ + Given a partially typed node infer the type of ``node`` and its sub-nodes. + + Contrary to the regular inference, this method does not descend into already typed sub-nodes + and can be used as a lightweight way to restore type information during a pass. + + Note that this function alters the input node, which is usually desired, and more + performant. + + Arguments: + node: The :class:`itir.Node` to infer the types of. + """ + if node.type: # already inferred + return node + + instance = cls( + offset_provider_type=None, dimensions=None, allow_undeclared_symbols=True, reinfer=True ) + instance.visit(node, ctx=_INITIAL_CONTEXT) return node def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + # we found a node that is typed, do not descend into children + if self.reinfer and isinstance(node, itir.Node) and node.type: + if isinstance(node.type, ts.FunctionType): + return _type_synthesizer_from_function_type(node.type) + return node.type + result = super().visit(node, **kwargs) + if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): if node.type and not isinstance(node.type, ts.DeferredType): @@ -519,19 +553,23 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: ) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: - assert ( - node.value in self.dimensions - ), f"Dimension {node.value} not present in offset provider." - return ts.DimensionType(dim=self.dimensions[node.value]) + return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in # the frontend. - def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: + def visit_OffsetLiteral( + self, node: itir.OffsetLiteral, **kwargs + ) -> it_ts.OffsetLiteralType | ts.DeferredType: + # `self.dimensions` not available in re-inference mode. Skip since we don't care anyway. + if self.reinfer: + return ts.DeferredType(constraint=it_ts.OffsetLiteralType) + if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) ) else: + assert isinstance(self.dimensions, dict) assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) @@ -608,3 +646,5 @@ def visit_Node(self, node: itir.Node, **kwargs): infer = ITIRTypeInference.apply + +reinfer = ITIRTypeInference.apply_reinfer diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 22a04ec04a..6e9936c4af 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -94,6 +94,10 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: + if isinstance(lhs, ts.DeferredType): + return rhs + if isinstance(rhs, ts.DeferredType): + return lhs assert lhs == rhs return lhs diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index b6b70af07c..d4d7c60d69 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy # TODO: test failure when something is not typed after inference is run # TODO: test lift with no args @@ -15,6 +16,7 @@ import pytest +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.type_system import ( @@ -80,7 +82,9 @@ def expression_test_cases(): (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), it_ts.NamedRangeType(dim=Vertex), ), ( @@ -91,7 +95,9 @@ def expression_test_cases(): ), ( im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1) + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ) ), it_ts.DomainType(dims=[Vertex]), ), @@ -157,8 +163,14 @@ def expression_test_cases(): im.call("as_fieldop")( im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ), ) )(im.ref("inp", float_edge_k_field)), @@ -309,8 +321,12 @@ def test_cartesian_fencil_definition(): def test_unstructured_fencil_definition(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.Program( @@ -376,8 +392,12 @@ def test_function_definition(): def test_fencil_with_nb_field_input(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) testee = itir.Program( @@ -501,3 +521,21 @@ def test_as_fieldop_without_domain(): assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype ) + + +def test_reinference(): + testee = im.make_tuple(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)) + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == ts.TupleType(types=[float_i_field, float_i_field]) + + +def test_func_reinference(): + f_type = ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={}, + kw_only_args={}, + returns=float_i_field, + ) + testee = im.call(im.ref("f", f_type))() + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == float_i_field diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 28090ff1e2..938b998565 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -9,6 +9,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts +from next_tests.unit_tests.iterator_tests.test_type_inference import int_type def test_simple_make_tuple_tuple_get(): @@ -127,8 +128,8 @@ def test_letify_make_tuple_elements(): # anything that is not trivial, i.e. a SymRef, works here el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") testee = im.make_tuple(el1, el2) - expected = im.let(("_tuple_el_1", el1), ("_tuple_el_2", el2))( - im.make_tuple("_tuple_el_1", "_tuple_el_2") + expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( + im.make_tuple("__ct_el_1", "__ct_el_2") ) actual = CollapseTuple.apply( @@ -239,3 +240,74 @@ def test_tuple_get_on_untyped_ref(): actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) assert actual == testee + + +def test_if_make_tuple_reorder_cps(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_nested_if_make_tuple_reorder_cps(): + testee = im.let( + ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))), + )( + im.make_tuple( + im.tuple_get(1, "t1"), + im.tuple_get(0, "t1"), + im.tuple_get(1, "t2"), + im.tuple_get(0, "t2"), + ) + ) + expected = im.if_( + True, + im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)), + im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)), + ) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_nested(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.let("c", im.tuple_get(0, "t"))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"), "c") + ) + ) + expected = im.if_(True, im.make_tuple(2, 1, 1), im.make_tuple(4, 3, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_external(): + external_ref = im.tuple_get(0, im.ref("external", ts.TupleType(types=[int_type]))) + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(external_ref, im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(external_ref, 2, 1), im.make_tuple(external_ref, 4, 3)) + actual = CollapseTuple.apply( + testee, + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected From 33bb68bddb4be31027138fdcac69dade5bbe9ae8 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 15 Jan 2025 11:52:04 +0100 Subject: [PATCH 091/178] fix[next][dace]: remove unused connectivities (#1797) By design, the arrays for connectivity tables are initially created as transient and marked as non-transient during the lowering when they are used. At the end of lowering to SDFG, the unused connectivities (still transient arrays) should be removed. --- .../runners/dace_common/utility.py | 16 +++++++++++++++- .../runners/dace_fieldview/gtir_dataflow.py | 2 +- .../runners/dace_fieldview/gtir_sdfg.py | 12 ++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index ac15bc1cbf..3e99c27049 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -18,6 +18,11 @@ from gt4py.next.type_system import type_specifications as ts +# arrays for connectivity tables use the following prefix +CONNECTIVITY_INDENTIFIER_PREFIX: Final[str] = "connectivity_" +CONNECTIVITY_INDENTIFIER_RE: Final[re.Pattern] = re.compile(r"^connectivity_(.+)$") + + # regex to match the symbols for field shape and strides FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+") @@ -48,7 +53,16 @@ def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: def connectivity_identifier(name: str) -> str: - return f"connectivity_{name}" + return f"{CONNECTIVITY_INDENTIFIER_PREFIX}{name}" + + +def is_connectivity_identifier( + name: str, offset_provider_type: gtx_common.OffsetProviderType +) -> bool: + m = CONNECTIVITY_INDENTIFIER_RE.match(name) + if m is None: + return False + return m[1] in offset_provider_type def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 0376143883..22d6e17cad 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -401,7 +401,7 @@ def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr: view_shape = tuple(desc.shape[i] for i in local_dim_indices) view_strides = tuple(desc.strides[i] for i in local_dim_indices) view, _ = self.sdfg.add_view( - f"{field.dc_node.data}_view", + f"view_{field.dc_node.data}", view_shape, desc.dtype, strides=view_strides, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index baddb7b699..7cb1461746 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -470,6 +470,18 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: head_state._debuginfo = dace_utils.debug_info(stmt, default=sdfg.debuginfo) head_state = self.visit(stmt, sdfg=sdfg, state=head_state) + # remove unused connectivity tables (by design, arrays are marked as non-transient when they are used) + for nsdfg in sdfg.all_sdfgs_recursive(): + unused_connectivities = [ + data + for data, datadesc in nsdfg.arrays.items() + if dace_utils.is_connectivity_identifier(data, self.offset_provider_type) + and datadesc.transient + ] + for data in unused_connectivities: + assert isinstance(nsdfg.arrays[data], dace.data.Array) + nsdfg.arrays.pop(data) + # Create the call signature for the SDFG. # Only the arguments required by the GT4Py program, i.e. `node.params`, are added # as positional arguments. The implicit arguments, such as the offset providers or From 17bae8ebabbf3bff8c862656083a369bb91cd28e Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 15 Jan 2025 12:20:05 +0100 Subject: [PATCH 092/178] feat[next][dace]: iterator-view support to DaCe backend (#1790) The lowering of scan to SDFG requires the support for iterator view. This PR introduces a subset of iterator features: - Local `if_` with exclusive branch execution - Lowering of `list_get`, `make_tuple` and `tuple_get` in iterator view - Field operators returning a tuple of fields - Tuple of fields with different size Iterator tests are enabled on dace CPU backend without SDFG transformations (`auto_optimize=False`). --------- Co-authored-by: Philip Mueller --- pyproject.toml | 8 + .../runners/dace_common/dace_backend.py | 31 +- .../runners/dace_common/workflow.py | 12 +- .../gtir_builtin_translators.py | 206 +++++-- .../runners/dace_fieldview/gtir_dataflow.py | 575 +++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 135 ++-- .../runners/dace_fieldview/utility.py | 81 ++- tests/next_tests/definitions.py | 25 +- .../ffront_tests/test_execution.py | 6 +- .../iterator_tests/test_builtins.py | 1 + .../iterator_tests/test_trivial.py | 2 + .../iterator_tests/test_tuple.py | 4 + .../iterator_tests/test_column_stencil.py | 9 +- .../test_with_toy_connectivity.py | 4 + tests/next_tests/unit_tests/conftest.py | 4 + .../dace_tests/test_gtir_to_sdfg.py | 62 +- 16 files changed, 875 insertions(+), 290 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 78735116ed..88bb2feac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,16 +237,23 @@ markers = [ 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', + 'uses_can_deref: tests that require backend support for can_deref builtin function', + 'uses_composite_shifts: tests that use composite shifts in unstructured domain', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', + 'uses_ir_if_stmts', + 'uses_lift: tests that require backend support for lift builtin function', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', + 'uses_reduce_with_lambda: tests that use lambdas as reduce functions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', + 'uses_scalar_in_domain_and_fo', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', + 'uses_scan_in_stencil: tests that require backend support for scan in stencil', 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', 'uses_scan_nested: tests that use nested scans', 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', @@ -254,6 +261,7 @@ markers = [ 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', + 'uses_tuple_iterator: tests that require backend support to deref tuple iterators', 'uses_tuple_returns: tests that require backend support for tuple results', 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', 'uses_cartesian_shift: tests that use a Cartesian connectivity', diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 90e7e07ad5..387619c667 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -7,13 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping, Sequence -from typing import Any, Iterable +from typing import Any import dace import numpy as np from gt4py._core import definitions as core_defs -from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next import common as gtx_common from . import utility as dace_utils @@ -46,10 +46,9 @@ def _convert_arg(arg: Any, sdfg_param: str) -> Any: def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names - flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { sdfg_param: _convert_arg(arg, sdfg_param) - for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) + for sdfg_param, arg in zip(sdfg_params, args, strict=True) } @@ -73,17 +72,8 @@ def _get_shape_args( for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): if isinstance(sym, dace.symbol): - if sym.name not in shape_args: - shape_args[sym.name] = size - elif shape_args[sym.name] != size: - # The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields - # in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array - # size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple. - # TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique, - # once the assumption on tuples is removed. - raise ValueError( - f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}." - ) + assert sym.name not in shape_args + shape_args[sym.name] = size elif sym != size: raise ValueError( f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." @@ -103,15 +93,8 @@ def _get_stride_args( f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) if isinstance(sym, dace.symbol): - if sym.name not in stride_args: - stride_args[str(sym)] = stride - elif stride_args[sym.name] != stride: - # See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple. - # TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique, - # once the assumption on tuples is removed. - raise ValueError( - f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}." - ) + assert sym.name not in stride_args + stride_args[sym.name] = stride elif sym != stride: raise ValueError( f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}." diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 5d9ac863c5..f0577ffaf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -10,7 +10,7 @@ import ctypes import dataclasses -from typing import Any +from typing import Any, Sequence import dace import factory @@ -112,11 +112,13 @@ def decorated_program( ) -> None: if out is not None: args = (*args, out) - if len(sdfg.arg_names) > len(args): - args = (*args, *arguments.iter_size_args(args)) + flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) + if len(sdfg.arg_names) > len(flat_args): + # The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments. + flat_args = (*flat_args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True)) + kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) use_fast_call = True @@ -151,7 +153,7 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, offset_provider, - *args, + *flat_args, check_args=False, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 354a9692d8..4cbc737312 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace from dace import subsets as dace_subsets @@ -27,6 +27,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, gtir_python_codegen, + gtir_sdfg, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -157,6 +158,33 @@ def get_local_view( """Data type used for field indexing.""" +def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. + """ + return ts.TupleType( + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + +def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. + """ + if isinstance(arg, tuple): + tuple_type = get_tuple_type(arg) + tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) + tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) + return [ + (str(tsym.id), tfield) + for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) + ] + else: + return [(name, arg)] + + class PrimitiveTranslator(Protocol): @abc.abstractmethod def __call__( @@ -191,16 +219,20 @@ def _parse_fieldop_arg( state: dace.SDFGState, sdfg_builder: gtir_sdfg.SDFGBuilder, domain: FieldopDomain, -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: +) -> ( + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr | tuple[Any, ...], ...] +): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, FieldopData): - raise ValueError(f"Received {node} as argument to field operator, expected a field.") - - return arg.get_local_view(domain) + if isinstance(arg, FieldopData): + return arg.get_local_view(domain) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda x: x.get_local_view(domain))(arg) def _get_field_layout( @@ -232,62 +264,107 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_field_operator( +def _create_field_operator_impl( + sdfg_builder: gtir_sdfg.SDFGBuilder, sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, - node_type: ts.FieldType, - sdfg_builder: gtir_sdfg.SDFGBuilder, - input_edges: Sequence[gtir_dataflow.DataflowInputEdge], output_edge: gtir_dataflow.DataflowOutputEdge, + output_type: ts.FieldType, + map_exit: dace.nodes.MapExit, ) -> FieldopData: """ - Helper method to allocate a temporary field to store the output of a field operator. + Helper method to allocate a temporary array that stores one field computed by a field operator. + + This method is called by `_create_field_operator()`. Args: + sdfg_builder: The object used to build the map scope in the provided SDFG. sdfg: The SDFG that represents the scope of the field data. state: The SDFG state where to create an access node to the field data. domain: The domain of the field operator that computes the field. - node_type: The GT4Py type of the IR node that produces this field. - sdfg_builder: The object used to build the map scope in the provided SDFG. - input_edges: List of edges to pass input data into the dataflow. - output_edge: Edge representing the dataflow output data. + output_edge: The dataflow write edge representing the output data. + output_type: The GT4Py field type descriptor. + map_exit: The `MapExit` node of the field operator map scope. Returns: The field data descriptor, which includes the field access node in the given `state` and the field domain offset. """ - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_indices = _get_domain_indices(field_dims, field_offset) - dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = dace_subsets.Range.from_indices(field_indices) + domain_dims, domain_offset, domain_shape = _get_field_layout(domain) + domain_indices = _get_domain_indices(domain_dims, domain_offset) + domain_subset = dace_subsets.Range.from_indices(domain_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == node_type.dtype - assert isinstance(dataflow_output_desc, dace.data.Scalar) - assert isinstance(node_type.dtype, ts.ScalarType) - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + assert output_edge.result.gt_dtype == output_type.dtype field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_subset = domain_subset else: - assert isinstance(node_type.dtype, ts.ListType) - assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type - assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type field_dtype = output_edge.result.gt_dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) assert output_edge.result.gt_dtype.offset_type is not None - field_dims.append(output_edge.result.gt_dtype.offset_type) - field_shape.extend(dataflow_output_desc.shape) - field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape[0]] + field_offset = [*domain_offset, dataflow_output_desc.offset[0]] + field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage - field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) + field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) field_node = state.add_access(field_name) + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(map_exit, field_node, field_subset) + + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) + + +def _create_field_operator( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: FieldopDomain, + node_type: ts.FieldType | ts.TupleType, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_edges: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], +) -> FieldopResult: + """ + Helper method to build the output of a field operator, which can consist of + a single field or a tuple of fields. + + A tuple of fields is returned when one stencil computes a grid point on multiple + fields: for each field, this method will call `_create_field_operator_impl()`. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edges: Single edge or tuple of edges representing the dataflow output data. + + Returns: + The descriptor of the field operator result, which can be either a single field + or a tuple fields. + """ + # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( + map_entry, map_exit = sdfg_builder.add_map( "fieldop", state, ndrange={ @@ -298,16 +375,21 @@ def _create_field_operator( # here we setup the edges passing through the map entry node for edge in input_edges: - edge.connect(me) - - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(mx, field_node, field_subset) + edge.connect(map_entry) - return FieldopData( - field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), - ) + if isinstance(node_type, ts.FieldType): + assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge) + return _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edges, node_type, map_exit + ) + else: + # handle tuples of fields + output_symbol_tree = dace_gtir_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edge, output_sym.type, map_exit + ) + )(output_edges, output_symbol_tree) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -366,16 +448,17 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) fun_node = node.fun assert len(fun_node.args) == 2 fieldop_expr, domain_expr = fun_node.args - assert isinstance(node.type, ts.FieldType) if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. + assert isinstance(node.type, ts.FieldType) stencil_expr = im.lambda_("a")(im.deref("a")) stencil_expr.expr.type = node.type.dtype elif isinstance(fieldop_expr, gtir.Lambda): @@ -394,12 +477,12 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - input_edges, output_edge = gtir_dataflow.visit_lambda( + input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( sdfg, state, sdfg_builder, stencil_expr, fieldop_args ) return _create_field_operator( - sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges ) @@ -458,7 +541,7 @@ def translate_if( def construct_output(inner_data: FieldopData) -> FieldopData: inner_desc = inner_data.dc_node.desc(sdfg) - outer, _ = sdfg.add_temp_transient_like(inner_desc) + outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc) outer_node = state.add_access(outer) return inner_data.make_copy(outer_node) @@ -518,8 +601,7 @@ def translate_index( dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - index_data = sdfg.temp_data_name() - sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_data, _ = sdfg_builder.add_temp_scalar(sdfg, INDEX_DTYPE) index_node = state.add_access(index_data) index_value = gtir_dataflow.ValueExpr( dc_node=index_node, @@ -570,11 +652,10 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) - return tuple( - _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) - for fname, ftype in tuple_fields - ) + symbol_tree = dace_gtir_utils.make_symbol_tree(data_name, data_type) + return gtx_utils.tree_map( + lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) + )(symbol_tree) else: raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") @@ -691,13 +772,11 @@ def translate_scalar_expr( visit_expr = True if isinstance(arg_expr, gtir.SymRef): try: - # `gt_symbol` refers to symbols defined in the GT4Py program - gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id) - if not isinstance(gt_symbol_type, ts.ScalarType): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") + # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined + sdfg_builder.get_symbol_type(arg_expr.id) except KeyError: - # this is the case of non-variable argument, e.g. target type such as `float64`, - # used in a casting expression like `cast_(variable, float64)` + # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, + # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` visit_expr = False if visit_expr: @@ -708,7 +787,7 @@ def translate_scalar_expr( sdfg=sdfg, head_state=state, ) - if not (isinstance(arg, FieldopData) and isinstance(arg.gt_type, ts.ScalarType)): + if not (isinstance(arg, FieldopData) and isinstance(node.type, ts.ScalarType)): raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") param = f"__arg{i}" args.append(arg.dc_node) @@ -738,12 +817,7 @@ def translate_scalar_expr( dace.Memlet(data=arg_node.data, subset="0"), ) # finally, create temporary for the result value - temp_name, _ = sdfg.add_scalar( - sdfg.temp_data_name(), - dace_utils.as_dace_type(node.type), - find_new_name=True, - transient=True, - ) + temp_name, _ = sdfg_builder.add_temp_scalar(sdfg, dace_utils.as_dace_type(node.type)) temp_node = state.add_access(temp_name) state.add_edge( tasklet_node, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 22d6e17cad..d086b26a2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -28,9 +28,10 @@ from dace import subsets as dace_subsets from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, @@ -115,6 +116,9 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_field_type(self) -> ts.FieldType: + return ts.FieldType([dim for dim, _ in self.field_domain], self.gt_dtype) + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -140,16 +144,19 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: class DataflowInputEdge(Protocol): """ - This protocol represents an open connection into the dataflow. + This protocol describes how to concretize a data edge to read data from a source node + into the dataflow. It provides the `connect` method to setup an input edge from an external data source. - Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope - and connect its inputs and outputs to external data nodes by means of memlets that - traverse the map entry and exit nodes. + The most common case is that the dataflow represents a stencil, which is instantied + inside a map scope and whose inputs and outputs are connected to external data nodes + by means of memlets that traverse the map entry and exit nodes. + The dataflow can also be instatiated without a map, in which case the `map_entry` + argument is set to `None`. """ @abc.abstractmethod - def connect(self, me: dace.nodes.MapEntry) -> None: ... + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... @dataclasses.dataclass(frozen=True) @@ -167,15 +174,18 @@ class MemletInputEdge(DataflowInputEdge): dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] - def connect(self, me: dace.nodes.MapEntry) -> None: + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: memlet = dace.Memlet(data=self.source.data, subset=self.subset) - self.state.add_memlet_path( - self.source, - me, - self.dest, - dst_conn=self.dest_conn, - memlet=memlet, - ) + if map_entry is None: + self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) + else: + self.state.add_memlet_path( + self.source, + map_entry, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) @dataclasses.dataclass(frozen=True) @@ -190,8 +200,12 @@ class EmptyInputEdge(DataflowInputEdge): state: dace.SDFGState node: dace.nodes.Tasklet - def connect(self, me: dace.nodes.MapEntry) -> None: - self.state.add_nedge(me, self.node, dace.Memlet()) + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + if map_entry is None: + # outside of a map scope it is possible to instantiate a tasklet node + # without input connectors + return + self.state.add_nedge(map_entry, self.node, dace.Memlet()) @dataclasses.dataclass(frozen=True) @@ -200,10 +214,12 @@ class DataflowOutputEdge: Allows to setup an output memlet through a map exit node. The result of a dataflow subgraph needs to be written to an external data node. - Since the dataflow represents a stencil and the dataflow is computed over - a field domain, the dataflow is instatiated inside a map scope. The `connect` - method creates a memlet that writes the dataflow result to the external array - passing through the map exit node. + The most common case is that the dataflow represents a stencil and the dataflow + is computed over a field domain, therefore the dataflow is instatiated inside + a map scope. The `connect` method creates a memlet that writes the dataflow + result to the external array passing through the `map_exit` node. + The dataflow can also be instatiated without a map, in which case the `map_exit` + argument is set to `None`. """ state: dace.SDFGState @@ -211,13 +227,13 @@ class DataflowOutputEdge: def connect( self, - mx: dace.nodes.MapExit, + map_exit: Optional[dace.nodes.MapExit], dest: dace.nodes.AccessNode, subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): + if isinstance(last_node, (dace.nodes.Tasklet, dace.nodes.NestedSDFG)): # the last transient node can be deleted last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn self.state.remove_node(self.result.dc_node) @@ -225,13 +241,22 @@ def connect( last_node = self.result.dc_node last_node_connector = None - self.state.add_memlet_path( - last_node, - mx, - dest, - src_conn=last_node_connector, - memlet=dace.Memlet(data=dest.data, subset=subset), - ) + if map_exit is None: + self.state.add_edge( + last_node, + last_node_connector, + dest, + None, + dace.Memlet(data=dest.data, subset=subset), + ) + else: + self.state.add_memlet_path( + last_node, + map_exit, + dest, + src_conn=last_node_connector, + memlet=dace.Memlet(data=dest.data, subset=subset), + ) DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { @@ -267,6 +292,25 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +def get_tuple_type( + data: tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...], +) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions. + """ + data_types: list[ts.DataType] = [] + for dataitem in data: + if isinstance(dataitem, tuple): + data_types.append(get_tuple_type(dataitem)) + elif isinstance(dataitem, IteratorExpr): + data_types.append(dataitem.get_field_type()) + elif isinstance(dataitem, MemletExpr): + data_types.append(dataitem.gt_dtype) + else: + data_types.append(dataitem.gt_dtype) + return ts.TupleType(data_types) + + @dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ @@ -289,9 +333,10 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( - default_factory=lambda: {} - ) + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] = dataclasses.field(default_factory=dict) def _add_input_data_edge( self, @@ -370,9 +415,9 @@ def _add_mapped_tasklet( name: str, map_ranges: Dict[str, str | dace.subsets.Subset] | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + inputs: Dict[str, dace.Memlet], code: str, - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """ @@ -427,10 +472,9 @@ def _construct_tasklet_result( # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) # in order to allow for composition of array shape in external memlets. - temp_name, _ = self.sdfg.add_temp_transient((1,), dc_dtype) + temp_name, _ = self.subgraph_builder.add_temp_array(self.sdfg, (1,), dc_dtype) else: - temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dc_dtype, transient=True) + temp_name, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, dc_dtype) temp_node = self.state.add_access(temp_name) self._add_edge( @@ -467,6 +511,9 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # format used for field index tasklet connector IndexConnectorFmt: Final = "__index_{dim}" + if isinstance(node.type, ts.TupleType): + raise NotImplementedError("Tuple deref not supported.") + assert len(node.args) == 1 arg_expr = self.visit(node.args[0]) @@ -545,6 +592,274 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + def _visit_if_branch_arg( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + param_name: str, + arg: IteratorExpr | DataExpr, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + ) -> IteratorExpr | ValueExpr: + """ + Helper method to be called by `_visit_if_branch()` to visit the input arguments. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + param_name: The parameter name of the input argument. + arg: The input argument expression. + if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + """ + if isinstance(arg, (MemletExpr, ValueExpr)): + arg_expr = arg + arg_node = arg.dc_node + arg_desc = arg_node.desc(self.sdfg) + if isinstance(arg, MemletExpr): + assert arg.subset.num_elements() == 1 + arg_desc = dace.data.Scalar(arg_desc.dtype) + else: + assert isinstance(arg_desc, dace.data.Scalar) + elif isinstance(arg, IteratorExpr): + arg_node = arg.field + arg_desc = arg_node.desc(self.sdfg) + arg_expr = MemletExpr(arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc)) + else: + raise TypeError(f"Unexpected {arg} as input argument.") + + if param_name in if_sdfg.arrays: + inner_desc = if_sdfg.data(param_name) + assert not inner_desc.transient + else: + inner_desc = arg_desc.clone() + inner_desc.transient = False + if_sdfg.add_datadesc(param_name, inner_desc) + if_sdfg_input_memlets[param_name] = arg_expr + + inner_node = if_branch_state.add_access(param_name) + if isinstance(arg, IteratorExpr): + return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) + else: + return ValueExpr(inner_node, arg.gt_dtype) + + def _visit_if_branch( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + expr: gtir.Expr, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + """ + Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. + + This function is called by `_visit_if()` for each if-branch. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + expr: The if branch expression to lower. + if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + + Returns: + A tuple containing: + - the list of input edges for the parent dataflow + - the output data, in the form of a single data edge or a tuple of data edges. + """ + assert if_branch_state in if_sdfg.states() + + lambda_args = [] + lambda_params = [] + for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): + arg = self.symbol_map[pname] + if isinstance(arg, tuple): + ptype = get_tuple_type(arg) # type: ignore[arg-type] + psymbol = im.sym(pname, ptype) + psymbol_tree = dace_gtir_utils.make_symbol_tree(pname, ptype) + inner_arg = gtx_utils.tree_map( + lambda tsym, targ: self._visit_if_branch_arg( + if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets + ) + )(psymbol_tree, arg) + else: + psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + inner_arg = self._visit_if_branch_arg( + if_sdfg, if_branch_state, pname, arg, if_sdfg_input_memlets + ) + lambda_args.append(inner_arg) + lambda_params.append(psymbol) + + # visit each branch of the if-statement as if it was a Lambda node + lambda_node = gtir.Lambda(params=lambda_params, expr=expr) + input_edges, output_edges = translate_lambda_to_dataflow( + if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, args=lambda_args + ) + + for data_node in if_branch_state.data_nodes(): + # In case tuple arguments, isolated non-transient nodes might be left in the state, + # because not all tuple fields are necessarily used in the lambda scope + if if_branch_state.degree(data_node) == 0: + assert not data_node.desc(if_sdfg).transient + if_branch_state.remove_node(data_node) + + return input_edges, output_edges + + def _visit_if_branch_result( + self, sdfg: dace.SDFG, state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + """ + Helper function to be called by `_visit_if` to create an output connector + on the nested SDFG that will write the result to the parent SDFG. + The result data inside the nested SDFG must have the same name as the connector. + """ + output_data = str(sym.id) + if output_data in sdfg.arrays: + output_desc = sdfg.data(output_data) + assert not output_desc.transient + else: + # If the result is currently written to a transient node, inside the nested SDFG, + # we need to allocate a non-transient data node. + result_desc = edge.result.dc_node.desc(sdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = state.add_access(output_data) + state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + """ + Lowers an if-expression with exclusive branch execution into a nested SDFG, + in which each branch is lowered into a dataflow in a separate state and + the if-condition is represented as the inter-state edge condtion. + """ + + def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: + # Each output connector of the nested SDFG writes to a transient node in the parent SDFG + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.subgraph_builder.add_temp_array_like(self.sdfg, inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + + assert len(node.args) == 3 + + # TODO(edopao): enable once supported in next DaCe release + use_conditional_block: Final[bool] = False + + # evaluate the if-condition that will write to a boolean scalar node + condition_value = self.visit(node.args[0]) + assert ( + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool_) + ) + + nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) + nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + + # create states inside the nested SDFG for the if-branches + if use_conditional_block: + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) + + else: + entry_state = nsdfg.add_state("entry", is_start_block=True) + tstate = nsdfg.add_state("true_branch") + nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond")) + fstate = nsdfg.add_state("false_branch") + nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) + + input_memlets: dict[str, MemletExpr | ValueExpr] = {} + + # define scalar or symbol for the condition value inside the nested SDFG + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + input_memlets["__cond"] = condition_value + + for nstate, arg in zip([tstate, fstate], node.args[1:3]): + # visit each if-branch in the corresponding state of the nested SDFG + in_edges, out_edge = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) + for edge in in_edges: + edge.connect(map_entry=None) + + if isinstance(out_edge, tuple): + assert isinstance(node.type, ts.TupleType) + out_symbol_tree = dace_gtir_utils.make_symbol_tree("__output", node.type) + outer_value = gtx_utils.tree_map( + lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) + )(out_edge, out_symbol_tree) + else: + assert isinstance(node.type, ts.FieldType | ts.ScalarType) + outer_value = self._visit_if_branch_result( + nsdfg, nstate, out_edge, im.sym("__output", node.type) + ) + # Isolated access node will make validation fail. + # Isolated access nodes can be found in `make_tuple` expressions that + # construct tuples from input arguments. + for data_node in nstate.data_nodes(): + if nstate.degree(data_node) == 0: + assert not data_node.desc(nsdfg).transient + nsdfg.remove_node(data_node) + else: + result = outer_value + + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, + self.sdfg, + inputs=set(input_memlets.keys()), + outputs=outputs, + symbol_mapping=None, # implicitly map all free symbols to the symbols available in parent SDFG + ) + + for inner, input_expr in input_memlets.items(): + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, inner) + else: + self._add_edge( + input_expr.dc_node, + None, + nsdfg_node, + inner, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + + return ( + gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) + if isinstance(result, tuple) + else write_output_of_nested_sdfg_to_temporary(result) + ) + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, ts.ListType) assert len(node.args) == 2 @@ -605,8 +920,8 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) ) - neighbors_temp, _ = self.sdfg.add_temp_transient( - (offset_provider.max_neighbors,), field_desc.dtype + neighbors_temp, _ = self.subgraph_builder.add_temp_array( + self.sdfg, (offset_provider.max_neighbors,), field_desc.dtype ) neighbors_node = self.state.add_access(neighbors_temp) offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL) @@ -652,6 +967,56 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type) ) + def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: + assert len(node.args) == 2 + index_arg = self.visit(node.args[0]) + list_arg = self.visit(node.args[1]) + assert isinstance(list_arg, ValueExpr) + assert isinstance(list_arg.gt_dtype, ts.ListType) + assert isinstance(list_arg.gt_dtype.element_type, ts.ScalarType) + + list_desc = list_arg.dc_node.desc(self.sdfg) + assert len(list_desc.shape) == 1 + + result_dtype = dace_utils.as_dace_type(list_arg.gt_dtype.element_type) + result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, result_dtype) + result_node = self.state.add_access(result) + + if isinstance(index_arg, SymbolExpr): + assert index_arg.dc_dtype in dace.dtypes.INTEGER_TYPES + self._add_edge( + list_arg.dc_node, + None, + result_node, + None, + dace.Memlet(data=list_arg.dc_node.data, subset=index_arg.value), + ) + elif isinstance(index_arg, ValueExpr): + tasklet_node = self._add_tasklet( + "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]" + ) + self._add_edge( + index_arg.dc_node, + None, + tasklet_node, + "index", + dace.Memlet(data=index_arg.dc_node.data, subset="0"), + ) + self._add_edge( + list_arg.dc_node, + None, + tasklet_node, + "list", + self.sdfg.make_array_memlet(list_arg.dc_node.data), + ) + self._add_edge( + tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0") + ) + else: + raise TypeError(f"Unexpected value {index_arg} as index argument.") + + return ValueExpr(dc_node=result_node, gt_dtype=list_arg.gt_dtype.element_type) + def _visit_map(self, node: gtir.FunCall) -> ValueExpr: """ A map node defines an operation to be mapped on all elements of input arguments. @@ -743,7 +1108,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: ) input_nodes[input_node.data] = input_node - result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) + result, _ = self.subgraph_builder.add_temp_array(self.sdfg, (local_size,), dc_dtype) result_node = self.state.add_access(result) if offset_provider_type.has_skip_values: @@ -930,8 +1295,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: op_name, reduce_init, reduce_identity = get_reduce_params(node) reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") - result = self.sdfg.temp_data_name() - self.sdfg.add_scalar(result, reduce_identity.dc_dtype, transient=True) + result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, reduce_identity.dc_dtype) result_node = self.state.add_access(result) input_expr = self.visit(node.args[0]) @@ -1119,10 +1483,7 @@ def _make_unstructured_shift( """Implements shift in unstructured domain by means of a neighbor table.""" assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain - assert neighbor_dim not in it.indices - origin_dim = connectivity.source_dim - assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1253,13 +1614,45 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: + def _visit_make_tuple(self, node: gtir.FunCall) -> tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "make_tuple") + return tuple(self.visit(arg) for arg in node.args) + + def _visit_tuple_get( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + tuple_fields = self.visit(node.args[1]) + return tuple_fields[index] + + def visit_FunCall( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) + elif cpm.is_call_to(node, "if_"): + return self._visit_if(node) + elif cpm.is_call_to(node, "neighbors"): return self._visit_neighbors(node) + elif cpm.is_call_to(node, "list_get"): + return self._visit_list_get(node) + + elif cpm.is_call_to(node, "make_tuple"): + return self._visit_make_tuple(node) + + elif cpm.is_call_to(node, "tuple_get"): + return self._visit_tuple_get(node) + elif cpm.is_applied_map(node): return self._visit_map(node) @@ -1279,35 +1672,52 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: - result: DataExpr = self.visit(node.expr) + def visit_Lambda( + self, node: gtir.Lambda + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: + def _visit_Lambda_impl( + output_expr: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, + ) -> DataflowOutputEdge: + if isinstance(output_expr, DataflowOutputEdge): + return output_expr + if isinstance(output_expr, ValueExpr): + return DataflowOutputEdge(self.state, output_expr) + + if isinstance(output_expr, MemletExpr): + # special case where the field operator is simply copying data from source to destination node + output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") + self._add_input_data_edge( + output_expr.dc_node, + output_expr.subset, + tasklet_node, + "__inp", + ) + else: + # even simpler case, where a constant value is written to destination node + output_dtype = output_expr.dc_dtype + tasklet_node = self._add_tasklet( + "write", {}, {"__out"}, f"__out = {output_expr.value}" + ) - if isinstance(result, ValueExpr): - return DataflowOutputEdge(self.state, result) + output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") + return DataflowOutputEdge(self.state, output_expr) - if isinstance(result, MemletExpr): - # special case where the field operator is simply copying data from source to destination node - output_dtype = result.dc_node.desc(self.sdfg).dtype - tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_data_edge( - result.dc_node, - result.subset, - tasklet_node, - "__inp", - ) - else: - # even simpler case, where a constant value is written to destination node - output_dtype = result.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") + result = self.visit(node.expr) - output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return DataflowOutputEdge(self.state, output_expr) + return ( + gtx_utils.tree_map(_visit_Lambda_impl)(result) + if isinstance(result, tuple) + else _visit_Lambda_impl(result) + ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: + def visit_SymRef( + self, node: gtir.SymRef + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1318,8 +1728,13 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE def visit_let( self, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], - ) -> DataflowOutputEdge: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: """ Maps lambda arguments to internal parameters. @@ -1349,13 +1764,21 @@ def visit_let( return self.visit(node) -def visit_lambda( +def translate_lambda_to_dataflow( sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], -) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + args: Sequence[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], +) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], +]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, that can be instantiated inside a map scope implementing the field operator. @@ -1367,7 +1790,7 @@ def visit_lambda( Args: sdfg: The SDFG where the dataflow graph will be instantiated. state: The SDFG state where the dataflow graph will be instantiated. - sdfg_builder: Helper class to build the SDFG. + sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. @@ -1377,5 +1800,5 @@ def visit_lambda( - Output data connection. """ taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) - output_edge = taskgen.visit_let(node, args) - return taskgen.input_edges, output_edge + output_edges = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edges diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7cb1461746..23a36ba79f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -53,6 +53,25 @@ def unique_map_name(self, name: str) -> str: ... @abc.abstractmethod def unique_tasklet_name(self, name: str) -> str: ... + def add_temp_array( + self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient(shape, dtype) + + def add_temp_array_like( + self, sdfg: dace.SDFG, datadesc: dace.data.Array + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient_like(datadesc) + + def add_temp_scalar( + self, sdfg: dace.SDFG, dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary scalar to the SDFG.""" + temp_name = sdfg.temp_data_name() + return sdfg.add_scalar(temp_name, dtype, transient=True) + def add_map( self, name: str, @@ -86,9 +105,9 @@ def add_mapped_tasklet( state: dace.SDFGState, map_ranges: Dict[str, str | dace.subsets.Subset] | List[Tuple[str, str | dace.subsets.Subset]], - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + inputs: Dict[str, dace.Memlet], code: str, - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Dict[str, dace.Memlet], **kwargs: Any, ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" @@ -149,15 +168,6 @@ def _collect_symbols_in_domain_expressions( ) -def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -173,9 +183,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( - default_factory=lambda: {} + default_factory=dict ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -246,7 +256,6 @@ def _add_storage( name: str, gt_type: ts.DataType, transient: bool = True, - tuple_name: Optional[str] = None, ) -> list[tuple[str, ts.DataType]]: """ Add storage in the SDFG for a given GT4Py data symbol. @@ -266,7 +275,6 @@ def _add_storage( name: Symbol Name to be allocated. gt_type: GT4Py symbol type. transient: True when the data symbol has to be allocated as internal storage. - tuple_name: Must be set for tuple fields in order to use the same array shape and strides symbols. Returns: List of tuples '(data_name, gt_type)' where 'data_name' is the name of @@ -277,11 +285,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): + for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): + assert isinstance(sym.type, ts.DataType) tuple_fields.extend( - self._add_storage( - sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name - ) + self._add_storage(sdfg, symbolic_arguments, sym.id, sym.type, transient) ) return tuple_fields @@ -293,16 +300,9 @@ def _add_storage( # ListType not supported: concept is represented as Field with local Dimension assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) - if tuple_name is None: - # Use symbolic shape, which allows to invoke the program with fields of different size; - # and symbolic strides, which enables decoupling the memory layout from generated code. - sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) - else: - # All fields in a tuple must have the same dims and sizes, - # therefore we use the same shape and strides symbols based on 'tuple_name'. - sym_shape, sym_strides = self._make_array_shape_and_strides( - tuple_name, gt_type.dims - ) + # Use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. + sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) return [(name, gt_type)] @@ -367,7 +367,7 @@ def make_temps( if desc.transient or not use_temp: return field else: - temp, _ = sdfg.add_temp_transient_like(desc) + temp, _ = self.add_temp_array_like(sdfg, desc) temp_node = head_state.add_access(temp) head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) @@ -438,13 +438,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: assert len(self.field_offsets) == 0 sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) - - # DaCe requires C-compatible strings for the names of data containers, - # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name - # separator in the SSA pass, which generates invalid symbols for DaCe. - # Here we find new names for invalid symbols present in the IR. - node = dace_gtir_utils.replace_invalid_symbols(sdfg, node) + sdfg.debuginfo = dace_utils.debug_info(node) # start block of the stateful graph entry_state = sdfg.add_state("program_entry", is_start_block=True) @@ -633,24 +627,13 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] - def flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = _get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - lambda_arg_nodes = dict( - itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + itertools.chain( + *[ + gtir_builtin_translators.flatten_tuples(pname, arg) + for pname, arg in lambda_args_mapping + ] + ) ) # inherit symbols from parent scope but eventually override with local symbols @@ -658,7 +641,9 @@ def flatten_tuples( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: gtir_builtin_translators.get_tuple_type(arg) + if isinstance(arg, tuple) + else arg.gt_type for pname, arg in lambda_args_mapping } @@ -673,12 +658,12 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( - lambda field_offsets, field: ( - field_offsets | get_field_domain_offset(field[0], field[1]) + lambda field_offsets, sym: ( + field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] ), - p_fields, + tsyms, {}, ) return {} @@ -722,15 +707,24 @@ def get_field_domain_offset( } input_memlets = {} - nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): if nsdfg_datadesc.transient: continue - datadesc: Optional[dace.dtypes.Data] = None + if nsdfg_dataname in lambda_arg_nodes: src_node = lambda_arg_nodes[nsdfg_dataname].dc_node dataname = src_node.data datadesc = src_node.desc(sdfg) + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], + [*datadesc.shape, *datadesc.strides], + strict=True, + ) + if isinstance(nested_symbol, dace.symbol) + } else: dataname = nsdfg_dataname datadesc = sdfg.arrays[nsdfg_dataname] @@ -741,16 +735,6 @@ def get_field_domain_offset( input_memlets[nsdfg_dataname] = sdfg.make_array_memlet(dataname) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], - [*datadesc.shape, *datadesc.strides], - strict=True, - ) - if isinstance(nested_symbol, dace.symbol) - } - # Process lambda outputs # # The output arguments do not really exist, so they are not allocated before @@ -817,7 +801,7 @@ def construct_output_for_nested_sdfg( # that is externally allocated, as required by the SDFG IR. An output edge will write the result # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. inner_desc.transient = False - outer, outer_desc = sdfg.add_temp_transient_like(inner_desc) + outer, outer_desc = self.add_temp_array_like(sdfg, inner_desc) # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. dace.symbolic.safe_replace( nsdfg_symbols_mapping, @@ -884,6 +868,13 @@ def build_sdfg_from_gtir( ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) + + # DaCe requires C-compatible strings for the names of data containers, + # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name + # separator in the SSA pass, which generates invalid symbols for DaCe. + # Here we find new names for invalid symbols present in the IR. + ir = dace_gtir_utils.replace_invalid_symbols(ir) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index c46420c24b..6121529161 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -8,14 +8,14 @@ from __future__ import annotations -import itertools from typing import Dict, TypeVar import dace from gt4py import eve -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -27,43 +27,55 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: return f"i_{dim.value}_gtx_{dim.kind}{suffix}" -def get_tuple_fields( - tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> list[tuple[str, ts.DataType]]: +def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: """ - Creates a list of names with the corresponding data type for all elements of the given tuple. + Creates a tree representation of the symbols corresponding to the tuple fields. + The constructed tree preserves the nested nature of the tuple type, if any. Examples -------- >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert get_tuple_fields("a", t, flatten=True) == [ - ... ("a_0", sty), - ... ("a_1_0", fty), - ... ("a_1_1", sty), - ... ] + >>> assert make_symbol_tree("a", t) == ( + ... im.sym("a_0", sty), + ... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)), + ... ) """ assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] - if flatten: - expanded_fields: list[list[tuple[str, ts.DataType]]] = [ - get_tuple_fields(field_name, field_type) - if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] # type: ignore[list-item] # checked in assert - for field_name, field_type in fields - ] - return list(itertools.chain(*expanded_fields)) - else: - return fields # type: ignore[return-value] # checked in assert - - -def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: + return tuple( + make_symbol_tree(field_name, field_type) # type: ignore[misc] + if isinstance(field_type, ts.TupleType) + else im.sym(field_name, field_type) + for field_name, field_type in fields + ) + + +def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: + """ + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert flatten_tuple_fields("a", t) == [ + ... im.sym("a_0", sty), + ... im.sym("a_1_0", fty), + ... im.sym("a_1_1", sty), + ... ] + """ + symbol_tree = make_symbol_tree(tuple_name, tuple_type) + return list(gtx_utils.flatten_nested_tuple(symbol_tree)) + + +def replace_invalid_symbols(ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). - If any invalid symbol present, this funtion returns a copy of the input IR where + If any invalid symbol present, this function returns a copy of the input IR where the invalid symbols have been replaced with new names. If all symbols are valid, the input IR is returned without copying it. """ @@ -85,12 +97,17 @@ def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.S if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params): raise ValueError("Invalid symbol in program parameters.") + ir_sym_ids = {str(sym.id) for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set()} + ir_ssa_uuid = eve.utils.UIDGenerator(prefix="gtir_tmp") + invalid_symbols_mapping = { - sym_id: sdfg.temp_data_name() - for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set() - if not dace.dtypes.validate_name(sym_id := str(sym.id)) + sym_id: ir_ssa_uuid.sequential_id() + for sym_id in ir_sym_ids + if not dace.dtypes.validate_name(sym_id) } - if len(invalid_symbols_mapping) != 0: - return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) - else: + if len(invalid_symbols_mapping) == 0: return ir + + # assert that the new symbol names are not used in the IR + assert ir_sym_ids.isdisjoint(invalid_symbols_mapping.values()) + return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index bed6e89a52..e19d9e1d81 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -86,14 +86,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ALL = "all" REQUIRES_ATLAS = "requires_atlas" USES_APPLIED_SHIFTS = "uses_applied_shifts" +USES_CAN_DEREF = "uses_can_deref" +USES_COMPOSITE_SHIFTS = "uses_composite_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" +USES_LIFT = "uses_lift" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" +USES_REDUCE_WITH_LAMBDA = "uses_reduce_with_lambda" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" @@ -105,6 +109,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" +USES_TUPLE_ITERATOR = "uses_tuple_iterator" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" USES_CARTESIAN_SHIFT = "uses_cartesian_shift" @@ -132,11 +137,21 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): DOMAIN_INFERENCE_SKIP_LIST = [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), -] +DACE_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), + (USES_COMPOSITE_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9de4449ac2..2e40cb897a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -360,6 +360,7 @@ def testee(qc: cases.IKFloatField, scalar: float): @pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator +@pytest.mark.uses_tuple_iterator def test_tuple_scalar_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( @@ -867,8 +868,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) -@pytest.mark.uses_tuple_args @pytest.mark.uses_scan +@pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -897,6 +899,7 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_different_domain_in_tuple(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] @@ -936,6 +939,7 @@ def foo( @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_tuple_field_scalar_mixed(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c0a4cd166d..885a272bfe 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -244,6 +244,7 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index fe89fe7c9d..7836b1b110 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -38,6 +38,7 @@ def baz(baz_inp): return deref(lift(bar)(baz_inp)) +@pytest.mark.uses_lift def test_trivial(program_processor): program_processor, validate = program_processor @@ -66,6 +67,7 @@ def stencil_shifted_arg_to_lift(inp): return deref(lift(deref)(shift(I, -1)(inp))) +@pytest.mark.uses_lift def test_shifted_arg_to_lift(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 39d0bd69c3..ea89bb23ba 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -219,6 +219,7 @@ def tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -272,6 +273,7 @@ def tuple_tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor @@ -319,6 +321,7 @@ def test_field_of_2_extra_dim_input(program_processor): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_scalar_tuple_args(program_processor): @fundef def stencil(inp): @@ -348,6 +351,7 @@ def stencil(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_mixed_field_scalar_tuple_arg(program_processor): @fundef def stencil(inp): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index f8e9f22eff..3b4fc0a70c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -122,19 +122,19 @@ def k_level_condition_upper_tuple(k_idx, k_level): @pytest.mark.parametrize( "fun, k_level, inp_function, ref_function", [ - ( + pytest.param( k_level_condition_lower, lambda inp: 0, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([[0], inp[:-1]]), ), - ( + pytest.param( k_level_condition_upper, lambda inp: inp.shape[0] - 1, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), - ( + pytest.param( k_level_condition_upper_tuple, lambda inp: inp[0].shape[0] - 1, lambda k_size: ( @@ -142,6 +142,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + marks=pytest.mark.uses_tuple_iterator, ), ], ) @@ -184,6 +185,7 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): "kstart, reference", [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) +@pytest.mark.uses_scan def test_ksum_scan(program_processor, kstart, reference): program_processor, validate = program_processor shape = [1, 7] @@ -211,6 +213,7 @@ def ksum_back_fencil(i_size, k_size, inp, out): set_at(as_fieldop(scan(ksum, False, 0.0), domain)(inp), domain, out) +@pytest.mark.uses_scan def test_ksum_back_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ac7ce9e544..ff87de7348 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -149,6 +149,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) +@pytest.mark.uses_composite_shifts def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor): program_processor, validate = program_processor inp = vertex_index_field() @@ -174,6 +175,7 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field(program_processor): program_processor, validate = program_processor @@ -196,6 +198,7 @@ def test_sparse_input_field(program_processor): assert np.allclose(out.asnumpy(), ref) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field_v2v(program_processor): program_processor, validate = program_processor @@ -330,6 +333,7 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) +@pytest.mark.uses_lift def test_lift(program_processor): program_processor, validate = program_processor inp = vertex_index_field() diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 03662f8dcc..0bd8653a03 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,6 +60,10 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), + pytest.param( + (next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, True), + marks=pytest.mark.requires_dace, + ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 03b8e3bc15..225d22562f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -256,7 +256,16 @@ def test_gtir_tuple_args(): x_fields = (a, a, b) - sdfg(*x_fields, c, **FSYMBOLS) + tuple_symbols = { + "__x_0_size_0": N, + "__x_0_stride_0": 1, + "__x_1_0_size_0": N, + "__x_1_0_stride_0": 1, + "__x_1_1_size_0": N, + "__x_1_1_stride_0": 1, + } + + sdfg(*x_fields, c, **FSYMBOLS, **tuple_symbols) assert np.allclose(c, a * 2 + b) @@ -418,7 +427,16 @@ def test_gtir_tuple_return(): z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_0_size_0": N, + "__z_0_0_stride_0": 1, + "__z_0_1_size_0": N, + "__z_0_1_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], a) assert np.allclose(z_fields[2], b) @@ -673,9 +691,16 @@ def test_gtir_cond_with_tuple_return(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + for s in [False, True]: z_fields = (np.empty_like(a), np.empty_like(a)) - sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS) + sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a if s else b) assert np.allclose(z_fields[1], b if s else a) @@ -1846,7 +1871,14 @@ def test_gtir_let_lambda_with_tuple1(): a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a_ref) assert np.allclose(z_fields[1], b_ref) @@ -1886,7 +1918,16 @@ def test_gtir_let_lambda_with_tuple2(): z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - sdfg(a, b, *z_fields, **FSYMBOLS) + tuple_symbols = { + "__z_0_size_0": N, + "__z_0_stride_0": 1, + "__z_1_size_0": N, + "__z_1_stride_0": 1, + "__z_2_size_0": N, + "__z_2_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], val) assert np.allclose(z_fields[2], b) @@ -1938,8 +1979,17 @@ def test_gtir_if_scalars(): sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + tuple_symbols = { + "__x_0_size_0": N, + "__x_0_stride_0": 1, + "__x_1_0_size_0": N, + "__x_1_0_stride_0": 1, + "__x_1_1_size_0": N, + "__x_1_1_stride_0": 1, + } + for s in [False, True]: - sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS) + sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert np.allclose(b, (a + d1 if s else a + d2)) From b73b6ff0d78eea712ebe1e329ae816032636fbac Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 15 Jan 2025 17:40:59 +0100 Subject: [PATCH 093/178] ci[cartesian]: disable test_K_offset_write on gt:gpu backend (#1800) Found other test failures related to #1684. --- .../multi_feature_tests/test_code_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 57c52eae12..000dc34c7f 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -583,7 +583,7 @@ def test_K_offset_write(backend): if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend == "dace:gpu": + if backend in ["gt:gpu", "dace:gpu"]: pytest.skip( f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1684" ) From 99b504277c48be45f72b24260827d9b9054318c4 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 16 Jan 2025 10:24:21 +0100 Subject: [PATCH 094/178] ci: disable x86-targets (DaintXC) in CSCS CI-Ext (#1801) Disable x86-targets (DaintXC) in CSCS CI-Ext, since DaintXC has been decommissioned. We only disable but still keep the x86 configuration in case in future we need to set it up on a Alps vCluster. --- ci/cscs-ci.yml | 56 +++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 349089ebfa..ad919d6bc0 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -51,19 +51,19 @@ stages: CUPY_VERSION: 13.3.0 UBUNTU_VERSION: 22.04 -build_py311_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py311 +# build_py311_baseimage_x86_64: +# extends: .build_baseimage_x86_64 +# variables: +# <<: *py311 build_py311_baseimage_aarch64: extends: .build_baseimage_aarch64 variables: <<: *py311 -build_py310_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py310 +# build_py310_baseimage_x86_64: +# extends: .build_baseimage_x86_64 +# variables: +# <<: *py310 build_py310_baseimage_aarch64: extends: .build_baseimage_aarch64 variables: @@ -83,22 +83,22 @@ build_py310_baseimage_aarch64: .build_image_aarch64: extends: [.container-builder-cscs-gh200, .build_image] -build_py311_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py311_baseimage_x86_64] - variables: - <<: *py311 +# build_py311_image_x86_64: +# extends: .build_image_x86_64 +# needs: [build_py311_baseimage_x86_64] +# variables: +# <<: *py311 build_py311_image_aarch64: extends: .build_image_aarch64 needs: [build_py311_baseimage_aarch64] variables: <<: *py311 -build_py310_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py310_baseimage_x86_64] - variables: - <<: *py310 +# build_py310_image_x86_64: +# extends: .build_image_x86_64 +# needs: [build_py310_baseimage_x86_64] +# variables: +# <<: *py310 build_py310_image_aarch64: extends: .build_image_aarch64 needs: [build_py310_baseimage_aarch64] @@ -149,22 +149,22 @@ build_py310_image_aarch64: # when high test parallelism is used. NUM_PROCESSES: 16 -test_py311_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py311_image_x86_64] - variables: - <<: *py311 +# test_py311_x86_64: +# extends: [.test_helper_x86_64] +# needs: [build_py311_image_x86_64] +# variables: +# <<: *py311 test_py311_aarch64: extends: [.test_helper_aarch64] needs: [build_py311_image_aarch64] variables: <<: *py311 -test_py310_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py310_image_x86_64] - variables: - <<: *py310 +# test_py310_x86_64: +# extends: [.test_helper_x86_64] +# needs: [build_py310_image_x86_64] +# variables: +# <<: *py310 test_py310_aarch64: extends: [.test_helper_aarch64] needs: [build_py310_image_aarch64] From 1b882761a58d0a58d4456dfa950e8a3c68b9b114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 16 Jan 2025 11:14:07 +0100 Subject: [PATCH 095/178] fix[dace][next]: Fix for DistributedBufferRelocator (#1799) This PR fixes an error that was reported by Edoardo (@edopao). The bug was because the `DistributedBufferRelocator` transformation did not check if its insertion would create a read-write conflict. This commit adds such a check, that is, however, not very sophisticated and needs some improvements. However, the example /`model/atmosphere/dycore/tests/dycore_stencil_tests/test_compute_exner_from_rhotheta.py`) where it surfaced, does hold more challenges. The main purpose of this PR is to unblock further development in ICON4Py. Link to ICON4Py PR: https://github.com/C2SM/icon4py/pull/638 --- .../transformations/simplify.py | 254 ++++++++++++++---- .../test_distributed_buffer_relocator.py | 217 ++++++++++++++- 2 files changed, 406 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 4339a761fa..bb95244aef 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -374,7 +374,7 @@ def apply( raise -AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode] +AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState] """Describes an access node and the state in which it is located. """ @@ -387,29 +387,38 @@ class DistributedBufferRelocator(dace_transformation.Pass): in each branch and then in the join state written back. Thus there is some additional storage needed. The transformation will look for the following situation: - - A transient data container, called `src_cont`, is written into another - container, called `dst_cont`, which is not transient. - - The access node of `src_cont` has an in degree of zero and an out degree of one. - - The access node of `dst_cont` has an in degree of of one and an + - A transient data container, called `temp_storage`, is written into another + container, called `dest_storage`, which is not transient. + - The access node of `temp_storage` has an in degree of zero and an out degree of one. + - The access node of `dest_storage` has an in degree of of one and an out degree of zero (this might be lifted). - - `src_cont` is not used afterwards. - - `dst_cont` is only used to implement the buffering. + - `temp_storage` is not used afterwards. + - `dest_storage` is only used to implement the buffering. - The function will relocate the writing of `dst_cont` to where `src_cont` is + The function will relocate the writing of `dest_storage` to where `temp_storage` is written, which might be multiple locations. It will also remove the writing back. It is advised that after this transformation simplify is run again. + The relocation will not take place if it might create data race. A necessary + but not sufficient condition for a data race is if `dest_storage` is present + in the state where `temp_storage` is defined. In addition at least one of the + following conditions has to be met: + - There are accesses to `dest_storage` that are not predecessor to the node where + the data is stored inside `temp_storage`. This check will ignore empty Memlets. + - There is a `dest_storage` access node, that has an output degree larger + than one. + Note: - Essentially this transformation removes the double buffering of `dst_cont`. - Because we ensure that that `dst_cont` is non transient this is okay, as our - rule guarantees this. + - Essentially this transformation removes the double buffering of + `dest_storage`. Because we ensure that that `dest_storage` is non + transient this is okay, as our rule guarantees this. Todo: - - Allow that `dst_cont` can also be transient. - - Allow that `dst_cont` does not need to be a sink node, this is most + - Allow that `dest_storage` can also be transient. + - Allow that `dest_storage` does not need to be a sink node, this is most likely most relevant if it is transient. - - Check if `dst_cont` is used between where we want to place it and + - Check if `dest_storage` is used between where we want to place it and where it is currently used. """ @@ -489,10 +498,10 @@ def _find_candidates( where the temporary is defined. """ # All nodes that are used as distributed buffers. - candidate_src_cont: list[AccessLocation] = [] + candidate_temp_storage: list[AccessLocation] = [] - # Which `src_cont` access node is written back to which global memory. - src_cont_to_global: dict[dace_nodes.AccessNode, str] = {} + # Which `temp_storage` access node is written back to which global memory. + temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {} for state in sdfg.states(): # These are the possible targets we want to write into. @@ -508,26 +517,26 @@ def _find_candidates( if len(candidate_dst_nodes) == 0: continue - for src_cont in state.source_nodes(): - if not isinstance(src_cont, dace_nodes.AccessNode): + for temp_storage in state.source_nodes(): + if not isinstance(temp_storage, dace_nodes.AccessNode): continue - if not src_cont.desc(sdfg).transient: + if not temp_storage.desc(sdfg).transient: continue - if state.out_degree(src_cont) != 1: + if state.out_degree(temp_storage) != 1: continue dst_candidate: dace_nodes.AccessNode = next( - iter(edge.dst for edge in state.out_edges(src_cont)) + iter(edge.dst for edge in state.out_edges(temp_storage)) ) if dst_candidate not in candidate_dst_nodes: continue - candidate_src_cont.append((src_cont, state)) - src_cont_to_global[src_cont] = dst_candidate.data + candidate_temp_storage.append((temp_storage, state)) + temp_storage_to_global[temp_storage] = dst_candidate.data - if len(candidate_src_cont) == 0: + if len(candidate_temp_storage) == 0: return [] # Now we have to find the places where the temporary sources are defined. - # I.e. This is also the location where the original value is defined. + # I.e. This is also the location where the temporary source was initialized. result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = [] def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: @@ -537,72 +546,199 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: if dst_state in reachable[src_state] and dst_state is not src_state } - for src_cont in candidate_src_cont: + for temp_storage in candidate_temp_storage: + temp_storage_node, temp_storage_state = temp_storage def_locations: list[AccessLocation] = [] - for upstream_state in find_upstream_states(src_cont[1]): - if src_cont[0].data in access_sets[upstream_state][1]: + for upstream_state in find_upstream_states(temp_storage_state): + if temp_storage_node.data in access_sets[upstream_state][1]: def_locations.extend( (data_node, upstream_state) for data_node in upstream_state.data_nodes() - if data_node.data == src_cont[0].data + if data_node.data == temp_storage_node.data ) if len(def_locations) != 0: - result_candidates.append((src_cont, def_locations)) + result_candidates.append((temp_storage, def_locations)) - # This transformation removes `src_cont` by writing its content directly - # to `dst_cont`, at the point where it is defined. + # This transformation removes `temp_storage` by writing its content directly + # to `dest_storage`, at the point where it is defined. # For this transformation to be valid the following conditions have to be met: - # - Between the definition of `src_cont` and the write back to `dst_cont`, - # `dst_cont` can not be accessed. - # - Between the definitions of `src_cont` and the point where it is written - # back, `src_cont` can only be accessed in the range that is written back. - # - After the write back point, `src_cont` shall not be accessed. This + # - Between the definition of `temp_storage` and the write back to `dest_storage`, + # `dest_storage` can not be accessed. + # - Between the definitions of `temp_storage` and the point where it is written + # back, `temp_storage` can only be accessed in the range that is written back. + # - After the write back point, `temp_storage` shall not be accessed. This # restriction could be lifted. # # To keep the implementation simple, we use the conditions: - # - `src_cont` is only accessed were it is defined and at the write back + # - `temp_storage` is only accessed were it is defined and at the write back # point. - # - Between the definitions of `src_cont` and the write back point, - # `dst_cont` is not used. + # - Between the definitions of `temp_storage` and the write back point, + # `dest_storage` is not used. result: list[tuple[AccessLocation, list[AccessLocation]]] = [] - for wb_localation, def_locations in result_candidates: + for wb_location, def_locations in result_candidates: + # Get the state and the location where the temporary is written back + # into the global data container. + wb_node, wb_state = wb_location + for def_node, def_state in def_locations: - # Test if `src_cont` is only accessed where it is defined and + # Test if `temp_storage` is only accessed where it is defined and # where it is written back. if gtx_transformations.util.is_accessed_downstream( start_state=def_state, sdfg=sdfg, - data_to_look=wb_localation[0].data, - nodes_to_ignore={def_node, wb_localation[0]}, + data_to_look=wb_node.data, + nodes_to_ignore={def_node, wb_node}, ): break # check if the global data is not used between the definition of - # `dst_cont` and where its written back. We allow one exception, - # if the global data is used in the state the distributed temporary - # is defined is used only for reading then it is ignored. This is - # allowed because of rule 3 of ADR0018. - glob_nodes_in_def_state = { - dnode - for dnode in def_state.data_nodes() - if dnode.data == src_cont_to_global[wb_localation[0]] + # `dest_storage` and where its written back. However, we ignore + # the state were `temp_storage` is defined. The checks if these + # checks are performed by the `_check_read_write_dependency()` + # function. + global_data_name = temp_storage_to_global[wb_node] + global_nodes_in_def_state = { + dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name } - if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state): - break if gtx_transformations.util.is_accessed_downstream( start_state=def_state, sdfg=sdfg, - data_to_look=src_cont_to_global[wb_localation[0]], - nodes_to_ignore=glob_nodes_in_def_state, - states_to_ignore={wb_localation[1]}, + data_to_look=global_data_name, + nodes_to_ignore=global_nodes_in_def_state, + states_to_ignore={wb_state}, ): break + if self._check_read_write_dependency(sdfg, wb_location, def_locations): + break else: - result.append((wb_localation, def_locations)) + result.append((wb_location, def_locations)) return result + def _check_read_write_dependency( + self, + sdfg: dace.SDFG, + write_back_location: AccessLocation, + target_locations: list[AccessLocation], + ) -> bool: + """Tests if read-write conflicts would be created. + + This function ensures that the substitution of `write_back_location` into + `target_locations` will not create a read-write conflict. + The rules that are used for this are outlined in the class description. + + Args: + sdfg: The SDFG on which we operate. + write_back_location: Where currently the write back occurs. + target_locations: List of the locations where we would like to perform + the write back instead. + + Returns: + If a read-write dependency is detected then the function will return + `True` and if none was detected `False` will be returned. + """ + for target_location in target_locations: + if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location): + return True + return False + + def _check_read_write_dependency_impl( + self, + sdfg: dace.SDFG, + write_back_location: AccessLocation, + target_location: AccessLocation, + ) -> bool: + """Tests if read-write conflict would be created for a single location. + + Args: + sdfg: The SDFG on which we operate. + write_back_location: Where currently the write back occurs. + target_locations: Location where the new write back should be performed. + + Todo: + Refine these checks later. + + Returns: + If a read-write dependency is detected then the function will return + `True` and if none was detected `False` will be returned. + """ + assert write_back_location[0].data == target_location[0].data + + # Get the state and the location where the temporary is written back + # into the global data container. Because `write_back_node` refers to + # the temporary we must query the graph to find the global node. + write_back_node, write_back_state = write_back_location + write_back_edge = next(iter(write_back_state.out_edges(write_back_node))) + global_data_name = write_back_edge.dst.data + assert not sdfg.arrays[global_data_name].transient + assert write_back_state.out_degree(write_back_node) == 1 + assert write_back_state.in_degree(write_back_node) == 0 + + # Get the location and the state where the temporary is originally defined. + def_location_of_intermediate, state_to_inspect = target_location + assert state_to_inspect.out_degree(def_location_of_intermediate) == 0 + + # These are all access nodes that refers to the global data, that we want + # to move into the state `state_to_inspect`. We need them to do the + # second test. + accesses_to_global_data: set[dace_nodes.AccessNode] = set() + + # In the first check we look for an access node, to the global data, that + # has an output degree larger than one. However, for this we ignore all + # empty Memlets. This is done because such Memlets are used to induce a + # schedule or order in the dataflow graph. + # As a byproduct, for the second test, we also collect all of these nodes. + for dnode in state_to_inspect.data_nodes(): + if dnode.data != global_data_name: + continue + dnode_degree = sum( + (1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty()) + ) + if dnode_degree > 1: + return True + # TODO(phimuell): Maybe AccessNodes with zero input degree should be ignored. + accesses_to_global_data.add(dnode) + + # There is no reference to the global data, so no need to do more tests. + if len(accesses_to_global_data) == 0: + return False + + # For the second test we will explore the dataflow graph, in reverse order, + # starting from the definition of the temporary node. If we find an access + # to the global data we remove it from the `accesses_to_global_data` list. + # If the list has not become empty, then we know that there is some sind + # branch (or concurrent dataflow) in this state that accesses the global + # data and we will have read-write conflicts. + # It is however, important to realize that passing this check does not + # imply that there are no read-write. We assume here that all accesses to + # the global data that was made before the write back were constructed in + # a correct way. + to_process: list[dace_nodes.Node] = [def_location_of_intermediate] + seen: set[dace_nodes.Node] = set() + while len(to_process) != 0: + node = to_process.pop() + seen.add(node) + + if isinstance(node, dace_nodes.AccessNode): + if node.data == global_data_name: + accesses_to_global_data.discard(node) + if len(accesses_to_global_data) == 0: + return False + + # Note that we only explore the ingoing edges, thus we will not necessarily + # explore the whole graph. However, this is fine, because we will see the + # relevant parts. To see that assume that we would also have to check the + # outgoing edges, this would mean that there was some branching point, + # which is a serialization point, so the dataflow would have been invalid + # before. + to_process.extend( + iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen + ) + + assert len(accesses_to_global_data) > 0 + return True + @dace_properties.make_properties class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index 1543a048ad..d61b8a2d42 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -13,7 +13,7 @@ transformations as gtx_transformations, ) -# from . import util +from . import util # dace = pytest.importorskip("dace") @@ -21,8 +21,8 @@ import dace -def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG("NAME") # util.unique_name("distributed_buffer_sdfg")) +def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg")) for name in ["a", "b", "tmp"]: sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) @@ -66,19 +66,224 @@ def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: sdfg.validate() assert sdfg.number_of_nodes() == 3 - return sdfg, state1 + return sdfg, state1, state3 def test_distributed_buffer_remover(): - sdfg, state1 = _mk_distributed_buffer_sdfg() + sdfg, state1, state3 = _mk_distributed_buffer_sdfg() assert state1.number_of_nodes() == 5 assert not any(dnode.data == "b" for dnode in state1.data_nodes()) res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) - assert res is not None + assert res[sdfg]["DistributedBufferRelocator"][state3] == {"tmp"} # Because the final state has now become empty assert sdfg.number_of_nodes() == 3 assert state1.number_of_nodes() == 6 assert any(dnode.data == "b" for dnode in state1.data_nodes()) assert any(dnode.data == "tmp" for dnode in state1.data_nodes()) + + +def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race")) + arr_names = ["a", "b", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + state1.add_nedge(a_state1, state1.add_access("b"), dace.Memlet("a[0:10, 0:10]")) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_race(): + """Tests if the transformation realized that it would create a data race. + + If the transformation would apply, then `a` is read twice, once from two + different branches, whose order of execution is indeterminate. + """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_race_sdfg() + assert state2.number_of_nodes() == 2 + + sdfg.simplify() + assert sdfg.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_race_sdfg2() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg")) + arr_names = ["a", "b", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in - 10", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state1, state2 + + +def test_distributed_buffer_global_memory_data_race2(): + """Tests if the transformation realized that it would create a data race. + + Similar situation but now there are two different subgraphs. This is needed + because it is another branch that checks it. + """ + sdfg, state1, state2 = _make_distributed_buffer_global_memory_data_race_sdfg2() + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_no_rance() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance(): + """Transformation applies if there is no data race. + + According to ADR18, pointwise dependencies are fine. This tests checks if the + checks for the read-write conflicts are not too strong. + """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_global_memory_data_no_rance2() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance2_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("a[__i0, __i1]")}, + output_nodes={a_state1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation2", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance2(): + """Transformation applies if there is no data race. + + These dependency is fine, because the access nodes are in a clear serial order. + """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance2() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0 From 489ccbb64d971dae189e056f80c71d25abf16661 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 17 Jan 2025 02:06:26 +0100 Subject: [PATCH 096/178] feat[next]: Output argument with non-zero domain start (#1780) ```python field = gtx.as_field(gtx.domain({IDim: (1, 10)}), arr) field_operator(out=field) ``` This PR also adds a test for non-zero domain start input arguments, which already worked before. --------- Co-authored-by: Edoardo Paone --- src/gt4py/next/ffront/past_process_args.py | 26 +++-- src/gt4py/next/ffront/past_to_itir.py | 37 ++++---- src/gt4py/next/otf/arguments.py | 9 +- .../runners/dace_common/utility.py | 2 +- .../runners/dace_common/workflow.py | 15 +-- .../dace_fieldview/gtir_python_codegen.py | 6 ++ .../feature_tests/dace/test_orchestration.py | 6 ++ .../feature_tests/dace/test_program.py | 3 + .../ffront_tests/test_execution.py | 2 +- .../ffront_tests/test_program.py | 41 +++++++- .../ffront_tests/test_past_to_gtir.py | 94 +++++++++++++++---- 11 files changed, 187 insertions(+), 54 deletions(-) diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 1add668791..ea4a2995e0 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -83,41 +83,47 @@ def _process_args( # TODO(tehrengruber): Previously this function was called with the actual arguments # not their type. The check using the shape here is not functional anymore and # should instead be placed in a proper location. - shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] + ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)] # check that all non-scalar like constituents have the same shape and dimension, e.g. # for `(scalar, (field1, field2))` the two fields need to have the same shape and # dimension - if shapes_and_dims: - shape, dims = shapes_and_dims[0] + if ranges_and_dims: + range_, dims = ranges_and_dims[0] if not all( - el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims + el_range == range_ and el_dims == dims + for (el_range, el_dims) in ranges_and_dims ): raise ValueError( "Constituents of composite arguments (e.g. the elements of a" " tuple) need to have the same shape and dimensions." ) + index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) size_args.extend( - shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty + range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty ) return tuple(rewritten_args), tuple(size_args), kwargs -def _field_constituents_shape_and_dims( +def _field_constituents_range_and_dims( arg: Any, # TODO(havogt): improve typing arg_type: ts.DataType, -) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]: +) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]: match arg_type: case ts.TupleType(): for el, el_type in zip(arg, arg_type.types): assert isinstance(el_type, ts.DataType) - yield from _field_constituents_shape_and_dims(el, el_type) + yield from _field_constituents_range_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) if isinstance(arg, ts.TypeSpec): # TODO yield (tuple(), dims) elif dims: - assert hasattr(arg, "shape") and len(arg.shape) == len(dims) - yield (arg.shape, dims) + assert ( + hasattr(arg, "domain") + and isinstance(arg.domain, common.Domain) + and len(arg.domain.dims) == len(dims) + ) + yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims) else: yield from [] # ignore 0-dim fields case ts.ScalarType(): diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4ec12bb76b..5adc229595 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] return iter(scanops_per_axis.keys()).__next__() -def _size_arg_from_field(field_name: str, dim: int) -> str: - return f"__{field_name}_size_{dim}" +def _range_arg_from_field(field_name: str, dim: int) -> str: + return f"__{field_name}_{dim}_range" def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: @@ -217,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: ) if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` assert all(field_dims == fields_dims[0] for field_dims in fields_dims) + index_type = ts.ScalarType( + kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ) for dim_idx in range(len(fields_dims[0])): size_params.append( itir.Sym( - id=_size_arg_from_field(param.id, dim_idx), - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + id=_range_arg_from_field(param.id, dim_idx), + type=ts.TupleType(types=[index_type, index_type]), ) ) @@ -286,7 +287,8 @@ def _visit_slice_bound( self, slice_bound: Optional[past.Constant], default_value: itir.Expr, - dim_size: itir.Expr, + start_idx: itir.Expr, + stop_idx: itir.Expr, **kwargs: Any, ) -> itir.Expr: if slice_bound is None: @@ -296,11 +298,9 @@ def _visit_slice_bound( slice_bound.type ) if slice_bound.value < 0: - lowered_bound = itir.FunCall( - fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)] - ) + lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs)) else: - lowered_bound = self.visit(slice_bound, **kwargs) + lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs)) else: raise AssertionError("Expected 'None' or 'past.Constant'.") if slice_bound: @@ -348,8 +348,9 @@ def _construct_itir_domain_arg( domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): - # an expression for the size of a dimension - dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i)) + # an expression for the range of a dimension + dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i)) + dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds lower: itir.Expr upper: itir.Expr @@ -359,11 +360,15 @@ def _construct_itir_domain_arg( else: lower = self._visit_slice_bound( slices[dim_i].lower if slices else None, - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - dim_size, + dim_start, + dim_start, + dim_stop, ) upper = self._visit_slice_bound( - slices[dim_i].upper if slices else None, dim_size, dim_size + slices[dim_i].upper if slices else None, + dim_stop, + dim_start, + dim_stop, ) if dim.kind == common.DimensionKind.LOCAL: diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 69d8985beb..c4235eaa9a 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -122,7 +122,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return None -def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]: +def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]: """ Yield the size of each field argument in each dimension. @@ -136,7 +136,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]: if first_field: yield from iter_size_args((first_field,)) case common.Field(): - yield from arg.ndarray.shape + for range_ in arg.domain.ranges: + assert isinstance(range_, common.UnitRange) + yield (range_.start, range_.stop) case _: pass @@ -156,6 +158,7 @@ def iter_size_compile_args( ) if field_constituents: # we only need the first field, because all fields in a tuple must have the same dims and sizes + index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) yield from [ - ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims + ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims ] diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 3e99c27049..a0f7711231 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -24,7 +24,7 @@ # regex to match the symbols for field shape and strides -FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+") +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_((\d+_range_[01])|((size|stride)_\d+))$") def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index f0577ffaf2..6fb7539c92 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -23,15 +23,16 @@ from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils -class CompiledDaceProgram(stages.CompiledProgram): +class CompiledDaceProgram(stages.ExtendedCompiledProgram): sdfg_program: dace.CompiledSDFG # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; # scalar arguments that are not used in the SDFG will not be present. sdfg_arglist: list[tuple[str, dace.dtypes.Data]] - def __init__(self, program: dace.CompiledSDFG): + def __init__(self, program: dace.CompiledSDFG, implicit_domain: bool): self.sdfg_program = program + self.implicit_domain = implicit_domain # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument # name to its data type, in the same order as arguments appear in the program ABI. # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. @@ -88,7 +89,7 @@ def __call__( dace.config.Config.set("compiler", "cpu", "args", value=compiler_args) sdfg_program = sdfg.compile(validate=False) - return CompiledDaceProgram(sdfg_program) + return CompiledDaceProgram(sdfg_program, inp.program_source.implicit_domain) class DaCeCompilationStepFactory(factory.Factory): @@ -113,9 +114,11 @@ def decorated_program( if out is not None: args = (*args, out) flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) - if len(sdfg.arg_names) > len(flat_args): - # The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments. - flat_args = (*flat_args, *arguments.iter_size_args(args)) + if inp.implicit_domain: + # generate implicit domain size arguments only if necessary + size_args = arguments.iter_size_args(args) + flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args)) + flat_args = (*flat_args, *flat_size_args) if sdfg_program._lastargs: kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 4bdb602f5f..2b3c5417cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -84,6 +84,11 @@ def builtin_if(*args: Any) -> str: return f"{true_val} if {cond} else {false_val}" +def builtin_tuple_get(*args: Any) -> str: + index, tuple_name = args + return f"{tuple_name}_{index}" + + def make_const_list(arg: str) -> str: """ Takes a single scalar argument and broadcasts this value on the local dimension @@ -97,6 +102,7 @@ def make_const_list(arg: str) -> str: "cast_": builtin_cast, "if_": builtin_if, "make_const_list": make_const_list, + "tuple_get": builtin_tuple_get, } diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index cd71c306eb..22af788845 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -42,6 +42,9 @@ def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") + # TODO(edopao): add support for range symbols in field domain and re-enable this test + pytest.skip("Requires support for field domain range.") + backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() @@ -87,6 +90,9 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") + # TODO(edopao): add support for range symbols in field domain and re-enable this test + pytest.skip("Requires support for field domain range.") + allocator, backend = unstructured_case.allocator, unstructured_case.backend if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index db0f90b409..2a7a3710a9 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -90,6 +90,9 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 def test_halo_exchange_helper_attrs(unstructured): local_int = gtx.int + # TODO(edopao): add support for range symbols in field domain and re-enable this test + pytest.skip("Requires support for field domain range.") + @gtx.field_operator(backend=unstructured.backend) def testee_op( a: gtx.Field[[Vertex, KDim], gtx.int], diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 2e40cb897a..644e0c6103 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -998,7 +998,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain + ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain ref[1:9] = a.asnumpy()[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f1cb8ffb17..27c4252e14 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -14,7 +14,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import errors +from gt4py.next import errors, constructors, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -251,3 +251,42 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): ValueError, match=(r"Dimensions in out field and field domain are not equivalent") ): cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={}) + + +@pytest.mark.uses_origin +def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def): + copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) + + size = cartesian_case.default_sizes[IDim] + + inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()() + out = constructors.empty( + common.domain({IDim: (1, size - 2)}), + allocator=cartesian_case.allocator, + ) + ref = inp.ndarray[1:-2] + + cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref) + + +@pytest.mark.uses_origin +def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def): + @gtx.field_operator + def identity(a: cases.IField) -> cases.IField: + return a + + @gtx.program + def copy_program(a: cases.IField, out: cases.IField): + identity(a, out=out, domain={IDim: (1, 9)}) + + inp = constructors.empty( + common.domain({IDim: (1, 9)}), + dtype=np.int32, + allocator=cartesian_case.allocator, + ) + inp.ndarray[...] = 42 + out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})() + ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain + ref[1:9] = inp.asnumpy() + + cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index c813285bd0..fa9a0220ef 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -58,8 +58,30 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="1", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), ], ) ], @@ -77,8 +99,8 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), + P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), + P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) @@ -105,18 +127,58 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) args=[ P(itir.AxisLiteral, value="IDim"), P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("plus")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.Literal, + value="1", + type=ts.ScalarType( + kind=getattr( + ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper() + ) + ), + ), + ], ), P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("plus")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.Literal, + value="2", + type=ts.ScalarType( + kind=getattr( + ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper() + ) + ), + ), + ], ), ], ) @@ -129,8 +191,8 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), + P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), + P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) From 1b17202b84b11279ddc626f5fcc28e90e84dcc7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 17 Jan 2025 07:25:51 +0100 Subject: [PATCH 097/178] fix[next][dace]: Fixing Strides Reconstruction During Propagation (#1802) NestedSDFG essentially allows to perform some slices, there are technically three chases: - The data container on the inside has a smaller rank than the one on the outside, thus some dimensions were removed. - The data container on the inside has the same rank than the one on the outside. - The data container on the inside has a larger rank than the one on the outside, thus some dimensions were added. The last case is not handled, as it does not happens in GT4Py. Before, the first and second case were handled together, but it was realized that the second case was not implemented properly and it was added explicitly. This PR fixes the issues with `TestFusedVelocityAdvectionStencil1To7` and `TestFusedVelocityAdvectionStencil8To13` in [ICON4Py#638](https://github.com/C2SM/icon4py/pull/638), however, the later test now fails with a segmentation fault. --- .../dace_fieldview/transformations/strides.py | 49 ++++++---- .../transformation_tests/test_strides.py | 96 +++++++++++++++++++ 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 980b2a8fdf..d1bf8fe266 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -498,31 +498,42 @@ def _gt_map_strides_into_nested_sdfg( inner_shape = inner_desc.shape inner_strides_init = inner_desc.strides + outer_shape = outer_desc.shape outer_strides = outer_desc.strides outer_inflow = outer_subset.size() - new_strides: list = [] - for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): - if dim_oinflow == 1: - # This is the case of implicit slicing along one dimension. - pass - else: - # There is inflow into the SDFG, so we need the stride. - new_strides.append(dim_ostride) - assert len(new_strides) <= len(inner_shape) - - # If we have a scalar on the inside, then there is nothing to adjust. - # We could have performed the test above, but doing it here, gives us - # the chance of validating it. if isinstance(inner_desc, dace_data.Scalar): - if len(new_strides) != 0: - raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") + # A scalar does not have a stride that must be propagated. return - if not isinstance(inner_desc, dace_data.Array): - raise TypeError( - f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." - ) + # Now determine the new stride that is needed on the inside. + new_strides: list = [] + if len(outer_shape) == len(inner_shape): + # The inner and the outer descriptor have the same dimensionality. + # We now have to decide if we should take the stride from the outside, + # which happens for example in case of `A[0:N, 0:M] -> B[N, M]`, or if we + # must take 1, which happens if we do `A[0:N, i] -> B[N, 1]`, we detect that + # based on the volume that flows in. + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + new_strides.append(1 if dim_oinflow == 1 else dim_ostride) + + elif len(inner_shape) < len(outer_shape): + # There are less dimensions on the inside than on the outside. This means + # that some were sliced away. We detect this case by checking if the Memlet + # subset in that dimension has size 1. + # NOTE: That this is not always correct as it might be possible that there + # are some explicit size 1 dimensions at several places. + new_strides = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + if dim_oinflow == 1: + pass + else: + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + else: + # The case that we have more dimensions on the inside than on the outside. + # This is currently not supported. + raise NotImplementedError("NestedSDFGs can not be used to increase the rank.") if len(new_strides) != len(inner_shape): raise ValueError("Failed to compute the inner strides.") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 5b16e41bc3..19b33d0bef 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -539,3 +539,99 @@ def ref(a1, b1): ref(**ref_args) sdfg_level1(**res_args) assert np.allclose(ref_args["b1"], res_args["b1"]) + + +def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg_level1.add_array( + name, + shape=(10, 1), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i, 0]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("b[__i, 0]")}, + external_edges=True, + ) + sdfg_level1.validate() + return sdfg_level1 + + +def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + state = sdfg.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg.add_array( + name, + shape=(10, 10), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + # Now get the nested SDFG. + sdfg_level1 = _make_strides_propagation_stride_1_nsdfg() + + nsdfg = state.add_nested_sdfg( + parent=sdfg, + sdfg=sdfg_level1, + inputs={"a"}, + outputs={"b"}, + symbol_mapping=None, + ) + + state.add_edge(state.add_access("a"), None, nsdfg, "a", dace.Memlet("a[0:10, 3]")) + state.add_edge(nsdfg, "b", state.add_access("b"), None, dace.Memlet("b[0:10, 2]")) + sdfg.validate() + return sdfg, nsdfg + + +def test_strides_propagation_stride_1(): + def ref(a, b): + for i in range(10): + b[i, 2] = a[i, 3] + 10.0 + + sdfg, nsdfg = _make_strides_propagation_stride_1_sdfg() + + outer_desc_a = sdfg.arrays["a"] + inner_desc_a = nsdfg.sdfg.arrays["a"] + assert outer_desc_a.strides == inner_desc_a.strides + + # Now switch the strides of `a` on the top level. + # Essentially going from `C` to FORTRAN order. + stride_outer_a_0, stride_outer_a_1 = outer_desc_a.strides + outer_desc_a.set_shape(outer_desc_a.shape, (stride_outer_a_1, stride_outer_a_0)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of(sdfg=sdfg, data_name="a") + + # Because of the propagation it must now been changed to `(1, 1)` on the inside. + assert inner_desc_a.strides == (1, 1) + + res_args = { + "a": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + "b": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + sdfg(**res_args, a_stride=10, b_stride=10) + ref(**ref_args) + assert np.allclose(ref_args["b"], res_args["b"]) From 517e1e9e5e6c6e8581dada8eca1917874baf32b6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 18 Jan 2025 15:33:04 +0100 Subject: [PATCH 098/178] feat[next]: Support for direct field operator call with domain arg (#1779) Adds support for directly calling a field operator with a domain argument, which was previously only supported inside of a program. Many field operators in icon4py use the domain argument resulting in excessive amounts of boilerplate programs that can be removed now. ```python @field_operator def testee(inp: IField) -> IField: return inp testee(inp, domain={IDim: (0, 10)}) ``` Support in the dace backend is missing and will be added in a seperate PR. --------- Co-authored-by: Edoardo Paone --- src/gt4py/next/ffront/decorator.py | 4 ++++ tests/next_tests/integration_tests/cases.py | 12 ++++++++--- .../ffront_tests/test_arg_call_interface.py | 20 ++++++++++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index d1631a461d..7e2abc44fb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -594,6 +594,10 @@ def __call__(self, *args, **kwargs) -> None: if "out" not in kwargs: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") + if "domain" in kwargs: + domain = common.domain(kwargs.pop("domain")) + out = out[domain] + args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs ) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 8a78307f87..c2b98ee8d9 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -381,6 +381,7 @@ def verify( fieldview_prog: decorator.FieldOperator | decorator.Program, *args: FieldViewArg, ref: ReferenceValue, + domain: Optional[dict[common.Dimension, tuple[int, int]]] = None, out: Optional[FieldViewInout] = None, inout: Optional[FieldViewInout] = None, offset_provider: Optional[OffsetProvider] = None, @@ -405,6 +406,8 @@ def verify( or tuple of fields here and they will be compared to ``ref`` under the assumption that the fieldview code stores its results in them. + domain: If given will be passed to the fieldview code as ``domain=`` + keyword argument. offset_provider: An override for the test case's offset_provider. Use with care! comparison: A comparison function, which will be called as @@ -414,10 +417,13 @@ def verify( used as an argument to the fieldview program and compared against ``ref``. Else, ``inout`` will not be passed and compared to ``ref``. """ + kwargs = {} if out: - run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider) - else: - run(case, fieldview_prog, *args, offset_provider=offset_provider) + kwargs["out"] = out + if domain: + kwargs["domain"] = domain + + run(case, fieldview_prog, *args, **kwargs, offset_provider=offset_provider) out_comp = out or inout assert out_comp is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index cb535f9596..8f67c1d198 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -13,7 +13,7 @@ import numpy as np import pytest -from gt4py.next import errors +from gt4py.next import errors, common, constructors from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import broadcast, int32 @@ -296,3 +296,21 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te ) is not None ) + + +@pytest.mark.uses_origin +def test_direct_fo_call_with_domain_arg(cartesian_case): + @field_operator + def testee(inp: IField) -> IField: + return inp + + size = cartesian_case.default_sizes[IDim] + inp = cases.allocate(cartesian_case, testee, "inp").unique()() + out = cases.allocate( + cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42) + )() + ref = inp.array_ns.zeros(size) + ref[0] = ref[-1] = 42 + ref[1:-1] = inp.ndarray[1:-1] + + cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref) From 0455024c0b9deaf33b35cdc04730f85d1664de9b Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 20 Jan 2025 12:38:27 +0100 Subject: [PATCH 099/178] ci[next]: Remove Github CI on GTIR branch (#1804) --- .github/workflows/_disabled/gt4py-sphinx.yml | 2 -- .github/workflows/code-quality.yml | 2 -- .github/workflows/test-cartesian-fallback.yml | 1 - .github/workflows/test-cartesian.yml | 2 -- .github/workflows/test-eve-fallback.yml | 1 - .github/workflows/test-eve.yml | 2 -- .github/workflows/test-next-fallback.yml | 1 - .github/workflows/test-next.yml | 2 -- .github/workflows/test-notebooks.yml | 2 -- .github/workflows/test-storage-fallback.yml | 1 - .github/workflows/test-storage.yml | 2 -- 11 files changed, 18 deletions(-) diff --git a/.github/workflows/_disabled/gt4py-sphinx.yml b/.github/workflows/_disabled/gt4py-sphinx.yml index cb3b275787..2533b2a42d 100644 --- a/.github/workflows/_disabled/gt4py-sphinx.yml +++ b/.github/workflows/_disabled/gt4py-sphinx.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index ee5ccce53c..d54fea9269 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 jobs: code-quality: diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index 76fd898159..8061ca56b9 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Inverse of corresponding workflow - "src/gt4py/next/**" - "tests/next_tests/**" diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index fd896c3d89..4b5a790f4d 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Skip if only gt4py.next and irrelevant doc files have been updated - "src/gt4py/next/**" - "tests/next_tests/**" diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 461400423f..78f6136888 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Inverse of corresponding workflow - "src/gt4py/eve/**" - "tests/eve_tests/**" diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index e83c4c563b..6b9f16e29b 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Run when gt4py.eve files (or package settings) are changed - "src/gt4py/eve/**" - "tests/eve_tests/**" diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index b8c39dc0e6..16a0cf0df3 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Inverse of corresponding workflow - "src/gt4py/cartesian/**" - "tests/cartesian_tests/**" diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 1460a5bdf4..35dcfe336b 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Skip if only gt4py.cartesian and irrelevant doc files have been updated - "src/gt4py/cartesian/**" - "tests/cartesian_tests/**" diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml index 4a65b7f30d..ae45cb154d 100644 --- a/.github/workflows/test-notebooks.yml +++ b/.github/workflows/test-notebooks.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 jobs: test-notebooks: diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 022c66b1f1..46a4442520 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Inverse of corresponding workflow - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index bfe6e49d23..a7f3b69c8d 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Run when gt4py.storage files (or package settings) are changed - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages From bf57c0cfeb48cb788e447c1989776f6fddd058e6 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 20 Jan 2025 16:17:52 +0100 Subject: [PATCH 100/178] feat[next][dace]: lowering of scan to SDFG (#1776) This PR contains the lowering of the scan builtin function. --------- Co-authored-by: Philip Mueller --- pyproject.toml | 1 + .../gtir_builtin_translators.py | 55 +- .../runners/dace_fieldview/gtir_dataflow.py | 35 +- .../dace_fieldview/gtir_scan_translator.py | 691 ++++++++++++++++++ .../runners/dace_fieldview/gtir_sdfg.py | 84 ++- .../runners/dace_fieldview/workflow.py | 2 +- tests/next_tests/definitions.py | 14 +- .../ffront_tests/test_execution.py | 1 + 8 files changed, 829 insertions(+), 54 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py diff --git a/pyproject.toml b/pyproject.toml index 88bb2feac6..a9f62b8ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -252,6 +252,7 @@ markers = [ 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scalar_in_domain_and_fo', 'uses_scan: tests that uses scan', + 'uses_scan_1d_field: that that uses scan on 1D vertical field', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_scan_in_stencil: tests that require backend support for scan in stencil', 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 4cbc737312..1b7e1e15c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -30,6 +30,7 @@ gtir_sdfg, utility as dace_gtir_utils, ) +from gt4py.next.program_processors.runners.dace_fieldview.gtir_scan_translator import translate_scan from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -37,7 +38,7 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -def _get_domain_indices( +def get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None ) -> dace_subsets.Indices: """ @@ -101,7 +102,7 @@ def get_local_view( if isinstance(self.gt_type, ts.FieldType): domain_dims = [dim for dim, _, _ in domain] - domain_indices = _get_domain_indices(domain_dims) + domain_indices = get_domain_indices(domain_dims) it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) for dim, index in zip(domain_dims, domain_indices) @@ -232,10 +233,10 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) else: # handle tuples of fields - return gtx_utils.tree_map(lambda x: x.get_local_view(domain))(arg) + return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain))(arg) -def _get_field_layout( +def get_field_layout( domain: FieldopDomain, ) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ @@ -274,7 +275,8 @@ def _create_field_operator_impl( map_exit: dace.nodes.MapExit, ) -> FieldopData: """ - Helper method to allocate a temporary array that stores one field computed by a field operator. + Helper method to allocate a temporary array that stores one field computed + by a field operator. This method is called by `_create_field_operator()`. @@ -288,17 +290,21 @@ def _create_field_operator_impl( map_exit: The `MapExit` node of the field operator map scope. Returns: - The field data descriptor, which includes the field access node in the given `state` - and the field domain offset. + The field data descriptor, which includes the field access node in the + given `state` and the field domain offset. """ dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - domain_dims, domain_offset, domain_shape = _get_field_layout(domain) - domain_indices = _get_domain_indices(domain_dims, domain_offset) + # the memory layout of the output field follows the field operator compute domain + domain_dims, domain_offset, domain_shape = get_field_layout(domain) + domain_indices = get_domain_indices(domain_dims, domain_offset) domain_subset = dace_subsets.Range.from_indices(domain_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): - assert output_edge.result.gt_dtype == output_type.dtype + if output_edge.result.gt_dtype != output_type.dtype: + raise TypeError( + f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." + ) field_dtype = output_edge.result.gt_dtype field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -306,8 +312,11 @@ def _create_field_operator_impl( else: assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) - assert output_edge.result.gt_dtype.element_type == output_type.dtype.element_type field_dtype = output_edge.result.gt_dtype.element_type + if field_dtype != output_type.dtype.element_type: + raise TypeError( + f"Type mismatch, expected {output_type.dtype.element_type} got {field_dtype}." + ) assert isinstance(dataflow_output_desc, dace.data.Array) assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) @@ -339,8 +348,7 @@ def _create_field_operator( node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_edges: gtir_dataflow.DataflowOutputEdge - | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], + output_tree: tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], ) -> FieldopResult: """ Helper method to build the output of a field operator, which can consist of @@ -356,11 +364,11 @@ def _create_field_operator( node_type: The GT4Py type of the IR node that produces this field. sdfg_builder: The object used to build the map scope in the provided SDFG. input_edges: List of edges to pass input data into the dataflow. - output_edges: Single edge or tuple of edges representing the dataflow output data. + output_tree: A tree representation of the dataflow output data. Returns: - The descriptor of the field operator result, which can be either a single field - or a tuple fields. + The descriptor of the field operator result, which can be either a single + field or a tuple fields. """ # create map range corresponding to the field operator domain @@ -378,9 +386,12 @@ def _create_field_operator( edge.connect(map_entry) if isinstance(node_type, ts.FieldType): - assert isinstance(output_edges, gtir_dataflow.DataflowOutputEdge) + assert len(output_tree) == 1 and isinstance( + output_tree[0], gtir_dataflow.DataflowOutputEdge + ) + output_edge = output_tree[0] return _create_field_operator_impl( - sdfg_builder, sdfg, state, domain, output_edges, node_type, map_exit + sdfg_builder, sdfg, state, domain, output_edge, node_type, map_exit ) else: # handle tuples of fields @@ -389,7 +400,7 @@ def _create_field_operator( lambda output_edge, output_sym: _create_field_operator_impl( sdfg_builder, sdfg, state, domain, output_edge, output_sym.type, map_exit ) - )(output_edges, output_symbol_tree) + )(output_tree, output_symbol_tree) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -454,6 +465,9 @@ def translate_as_fieldop( assert len(fun_node.args) == 2 fieldop_expr, domain_expr = fun_node.args + if cpm.is_call_to(fieldop_expr, "scan"): + return translate_scan(node, sdfg, state, sdfg_builder) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value @@ -627,7 +641,7 @@ def translate_index( ] output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) return _create_field_operator( - sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + sdfg, state, domain, node.type, sdfg_builder, input_edges, (output_edge,) ) @@ -859,5 +873,6 @@ def translate_symbol_ref( translate_make_tuple, translate_tuple_get, translate_scalar_expr, + translate_scan, translate_symbol_ref, ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index d086b26a2d..2c91e2d1b3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -649,7 +649,7 @@ def _visit_if_branch( if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], ) -> tuple[ list[DataflowInputEdge], - DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + tuple[DataflowOutputEdge | tuple[Any, ...], ...], ]: """ Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. @@ -665,7 +665,7 @@ def _visit_if_branch( Returns: A tuple containing: - the list of input edges for the parent dataflow - - the output data, in the form of a single data edge or a tuple of data edges. + - the tree representation of output data, in the form of a tuple of data edges. """ assert if_branch_state in if_sdfg.states() @@ -692,18 +692,18 @@ def _visit_if_branch( # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) - input_edges, output_edges = translate_lambda_to_dataflow( + input_edges, output_tree = translate_lambda_to_dataflow( if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, args=lambda_args ) for data_node in if_branch_state.data_nodes(): - # In case tuple arguments, isolated non-transient nodes might be left in the state, - # because not all tuple fields are necessarily used in the lambda scope + # In case of tuple arguments, isolated access nodes might be left in the state, + # because not all tuple fields are necessarily used inside the lambda scope if if_branch_state.degree(data_node) == 0: assert not data_node.desc(if_sdfg).transient if_branch_state.remove_node(data_node) - return input_edges, output_edges + return input_edges, output_tree def _visit_if_branch_result( self, sdfg: dace.SDFG, state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym @@ -807,20 +807,21 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp for nstate, arg in zip([tstate, fstate], node.args[1:3]): # visit each if-branch in the corresponding state of the nested SDFG - in_edges, out_edge = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) + in_edges, output_tree = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) for edge in in_edges: edge.connect(map_entry=None) - if isinstance(out_edge, tuple): - assert isinstance(node.type, ts.TupleType) + if isinstance(node.type, ts.TupleType): out_symbol_tree = dace_gtir_utils.make_symbol_tree("__output", node.type) outer_value = gtx_utils.tree_map( lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) - )(out_edge, out_symbol_tree) + )(output_tree, out_symbol_tree) else: assert isinstance(node.type, ts.FieldType | ts.ScalarType) + assert len(output_tree) == 1 and isinstance(output_tree[0], DataflowOutputEdge) + output_edge = output_tree[0] outer_value = self._visit_if_branch_result( - nsdfg, nstate, out_edge, im.sym("__output", node.type) + nsdfg, nstate, output_edge, im.sym("__output", node.type) ) # Isolated access node will make validation fail. # Isolated access nodes can be found in `make_tuple` expressions that @@ -1777,7 +1778,7 @@ def translate_lambda_to_dataflow( ], ) -> tuple[ list[DataflowInputEdge], - DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + tuple[DataflowOutputEdge | tuple[Any, ...], ...], ]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, @@ -1797,8 +1798,12 @@ def translate_lambda_to_dataflow( Returns: A tuple of two elements: - List of connections for data inputs to the dataflow. - - Output data connection. + - Tree representation of output data connections. """ taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) - output_edges = taskgen.visit_let(node, args) - return taskgen.input_edges, output_edges + lambda_output = taskgen.visit_let(node, args) + + if isinstance(lambda_output, DataflowOutputEdge): + return taskgen.input_edges, (lambda_output,) + else: + return taskgen.input_edges, lambda_output diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py new file mode 100644 index 0000000000..e105030908 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py @@ -0,0 +1,691 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the lowering of scan field operator. + +This builtin translator implements the `PrimitiveTranslator` protocol as other +translators in `gtir_builtin_translators` module. This module implements the scan +translator, separately from the `gtir_builtin_translators` module, because the +parsing of input arguments as well as the construction of the map scope differ +from a regular field operator, which requires slightly different helper methods. +Besides, the function code is quite large, another reason to keep it separate +from other translators. + +The current GTIR representation of the scan operator is based on iterator view. +This is likely to change in the future, to enable GTIR optimizations for scan. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable, Optional + +import dace +from dace import subsets as dace_subsets + +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils +from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_builtin_translators as gtir_translators, + gtir_dataflow, + gtir_sdfg, + utility as dace_gtir_utils, +) +from gt4py.next.type_system import type_info as ti, type_specifications as ts + + +if TYPE_CHECKING: + from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg + + +def _parse_scan_fieldop_arg( + node: gtir.Expr, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + domain: gtir_translators.FieldopDomain, +) -> gtir_dataflow.MemletExpr | tuple[gtir_dataflow.MemletExpr | tuple[Any, ...], ...]: + """Helper method to visit an expression passed as argument to a scan field operator. + + On the innermost level, a scan operator is lowered to a loop region which computes + column elements in the vertical dimension. + + It differs from the helper method `gtir_builtin_translators` in that field arguments + are passed in full shape along the vertical dimension, rather than as iterator. + """ + + def _parse_fieldop_arg_impl( + arg: gtir_translators.FieldopData, + ) -> gtir_dataflow.MemletExpr: + arg_expr = arg.get_local_view(domain) + if isinstance(arg_expr, gtir_dataflow.MemletExpr): + return arg_expr + # In scan field operator, the arguments to the vertical stencil are passed by value. + # Therefore, the full field shape is passed as `MemletExpr` rather than `IteratorExpr`. + return gtir_dataflow.MemletExpr( + arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg) + ) + + arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + + if isinstance(arg, gtir_translators.FieldopData): + return _parse_fieldop_arg_impl(arg) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda x: _parse_fieldop_arg_impl(x))(arg) + + +def _create_scan_field_operator_impl( + sdfg_builder: gtir_sdfg.SDFGBuilder, + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: gtir_translators.FieldopDomain, + output_edge: gtir_dataflow.DataflowOutputEdge, + output_type: ts.FieldType, + map_exit: dace.nodes.MapExit, +) -> gtir_translators.FieldopData: + """ + Helper method to allocate a temporary array that stores one field computed + by the scan field operator. + + This method is called by `_create_scan_field_operator()`. + + Similar to `gtir_builtin_translators._create_field_operator_impl()` but + for scan field operators. It differs in that the scan loop region produces + a field along the vertical dimension, rather than a single point. + Therefore, the memlet subset will write a slice into the result array, that + corresponds to the full vertical shape for each horizontal grid point. + + Refer to `gtir_builtin_translators._create_field_operator_impl()` for + the description of function arguments and return values. + """ + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + assert isinstance(dataflow_output_desc, dace.data.Array) + + # the memory layout of the output field follows the field operator compute domain + domain_dims, domain_offset, domain_shape = gtir_translators.get_field_layout(domain) + domain_indices = gtir_translators.get_domain_indices(domain_dims, domain_offset) + domain_subset = dace_subsets.Range.from_indices(domain_indices) + + # the vertical dimension used as scan column is computed by the `LoopRegion` + # inside the map scope, therefore it is excluded from the map range + scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in domain_dims].index(True) + + # the map scope writes the full-shape dimension corresponding to the scan column + field_subset = ( + dace_subsets.Range(domain_subset[:scan_dim_index]) + + dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}") + + dace_subsets.Range(domain_subset[scan_dim_index + 1 :]) + ) + + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert isinstance(output_type.dtype, ts.ScalarType) + if output_edge.result.gt_dtype != output_type.dtype: + raise TypeError( + f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." + ) + field_dtype = output_edge.result.gt_dtype + field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + # the scan field operator computes a column of scalar values + assert len(dataflow_output_desc.shape) == 1 + else: + assert isinstance(output_type.dtype, ts.ListType) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type + if field_dtype != output_type.dtype.element_type: + raise TypeError( + f"Type mismatch, expected {output_type.dtype.element_type} got {field_dtype}." + ) + # the scan field operator computes a list of scalar values for each column level + # 1st dim: column level, 2nd dim: list of scalar values (e.g. `neighbors`) + assert len(dataflow_output_desc.shape) == 2 + # the lines below extend the array with the local dimension added by the field operator + assert output_edge.result.gt_dtype.offset_type is not None + field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] + field_shape = [*domain_shape, dataflow_output_desc.shape[1]] + field_offset = [*domain_offset, dataflow_output_desc.offset[1]] + field_subset = field_subset + dace_subsets.Range.from_string( + f"0:{dataflow_output_desc.shape[1]}" + ) + + # allocate local temporary storage + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) + field_name, field_desc = sdfg_builder.add_temp_array( + sdfg, field_shape, dataflow_output_desc.dtype + ) + # the inner and outer strides have to match + scan_output_stride = field_desc.strides[scan_dim_index] + # also consider the stride of the local dimension, in case the scan field operator computes a list + local_strides = field_desc.strides[len(domain_dims) :] + assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0) + new_inner_strides = [scan_output_stride, *local_strides] + dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides) + + # and here the edge writing the dataflow result data through the map exit node + field_node = state.add_access(field_name) + output_edge.connect(map_exit, field_node, field_subset) + + return gtir_translators.FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) + + +def _create_scan_field_operator( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: gtir_translators.FieldopDomain, + node_type: ts.FieldType | ts.TupleType, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_tree: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], +) -> gtir_translators.FieldopResult: + """ + Helper method to build the output of a field operator, which can consist of + a single field or a tuple of fields. + + Similar to `gtir_builtin_translators._create_field_operator()` but for scan + field operators. The main difference is that the scan vertical dimension is + excluded from the map range. This because the vertical dimension is traversed + by a loop region in a mapped nested SDFG. + + Refer to `gtir_builtin_translators._create_field_operator()` for the + description of function arguments and return values. + """ + domain_dims, _, _ = gtir_translators.get_field_layout(domain) + + # create a map scope to execute the `LoopRegion` over the horizontal domain + if len(domain_dims) == 1: + # We construct the scan field operator on the horizontal domain, while the + # vertical dimension (the column axis) is computed by the loop region. + # If the field operator computes only the column axis (a 1d scan field operator), + # there is no horizontal domain, therefore the map scope is not needed. + # This case currently produces wrong CUDA code because of a DaCe issue + # (see https://github.com/GridTools/gt4py/issues/1136). + # The corresponding GT4Py tests are disabled (pytest marker `uses_scan_1d_field`). + map_entry, map_exit = (None, None) + else: + # create map range corresponding to the field operator domain + map_entry, map_exit = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + if not sdfg_builder.is_column_axis(dim) + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(map_entry) + + if isinstance(node_type, ts.FieldType): + assert isinstance(output_tree, gtir_dataflow.DataflowOutputEdge) + return _create_scan_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_tree, node_type, map_exit + ) + else: + # handle tuples of fields + # the symbol name 'x' in the call below is not used, we only need + # the tree structure of the `TupleType` definition to pass to `tree_map()` + output_symbol_tree = dace_gtir_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: ( + _create_scan_field_operator_impl( + sdfg_builder, + sdfg, + state, + domain, + output_edge, + output_sym.type, + map_exit, + ) + ) + )(output_tree, output_symbol_tree) + + +def _scan_input_name(input_name: str) -> str: + """ + Helper function to make naming of input connectors in the scan nested SDFG + consistent throughut this module scope. + """ + return f"__gtir_scan_input_{input_name}" + + +def _scan_output_name(input_name: str) -> str: + """ + Same as above, but for the output connecters in the scan nested SDFG. + """ + return f"__gtir_scan_output_{input_name}" + + +def _lower_lambda_to_nested_sdfg( + lambda_node: gtir.Lambda, + sdfg: dace.SDFG, + sdfg_builder: gtir_sdfg.SDFGBuilder, + domain: gtir_translators.FieldopDomain, + init_data: gtir_translators.FieldopResult, + lambda_symbols: dict[str, ts.DataType], + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], + scan_forward: bool, + scan_carry_symbol: gtir.Sym, +) -> tuple[dace.SDFG, gtir_translators.FieldopResult]: + """ + Helper method to lower the lambda node representing the scan stencil dataflow + inside a separate SDFG. + + In regular field operators, where the computation of a grid point is independent + from other points, therefore the stencil can be lowered to a mapped tasklet + dataflow, and the map range is defined on the full domain. + The scan field operator has to carry an intermediate result while the stencil + is applied on vertical levels, which is input to the computation of next level + (an accumulator function, for example). Therefore, the points on the vertical + dimension are computed inside a `LoopRegion` construct. + This function creates the `LoopRegion` inside a nested SDFG, which will be + mapped by the caller to the horizontal domain in the field operator context. + + Args: + lambda_node: The lambda representing the stencil expression on the horizontal level. + sdfg: The SDFG where the scan field operator is translated. + sdfg_builder: The SDFG builder object to access the field operator context. + domain: The field operator domain, with all horizontal and vertical dimensions. + init_data: The data produced in the field operator context that is used + to initialize the scan carry value. + lambda_symbols: List of symbols used as parameters of the stencil expressions. + lambda_field_offsets: Mapping from symbol name to field origin, + `None` if field origin is 0 in all dimensions. + scan_forward: When True, the loop should range starting from the origin; + when False, traverse towards origin. + scan_carry_symbol: The symbol used in the stencil expression to carry the + intermediate result along the vertical dimension. + + Returns: + A tuple of two elements: + - An SDFG containing the `LoopRegion` computation along the vertical + dimension, to be instantied as a nested SDFG in the field operator context. + - The inner fields, that is 1d arrays with vertical shape containing + the output of the stencil computation. These fields will have to be + mapped to outer arrays by the caller. The caller is responsible to ensure + that inner and outer arrays use the same strides. + """ + + # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. + nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) + nsdfg.debuginfo = dace_utils.debug_info(lambda_node, default=sdfg.debuginfo) + lambda_translator = sdfg_builder.setup_nested_context( + lambda_node, nsdfg, lambda_symbols, lambda_field_offsets + ) + + # use the vertical dimension in the domain as scan dimension + scan_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if sdfg_builder.is_column_axis(dim) + ] + assert len(scan_domain) == 1 + scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + + # extract the scan loop range + scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) + + # in case the scan operator computes a list (not a scalar), we need to add an extra dimension + def get_scan_output_shape( + scan_init_data: gtir_translators.FieldopData, + ) -> list[dace.symbolic.SymExpr]: + scan_column_size = scan_upper_bound - scan_lower_bound + if isinstance(scan_init_data.gt_type, ts.ScalarType): + return [scan_column_size] + assert isinstance(scan_init_data.gt_type, ts.ListType) + assert scan_init_data.gt_type.offset_type + offset_type = scan_init_data.gt_type.offset_type + offset_provider_type = sdfg_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + list_size = offset_provider_type.max_neighbors + return [scan_column_size, dace.symbolic.SymExpr(list_size)] + + if isinstance(init_data, tuple): + lambda_result_shape = gtx_utils.tree_map(get_scan_output_shape)(init_data) + else: + lambda_result_shape = get_scan_output_shape(init_data) + + # Create the body of the initialization state + # This dataflow will write the initial value of the scan carry variable. + init_state = nsdfg.add_state("scan_init", is_start_block=True) + scan_carry_input = ( + dace_gtir_utils.make_symbol_tree(scan_carry_symbol.id, scan_carry_symbol.type) + if isinstance(scan_carry_symbol.type, ts.TupleType) + else scan_carry_symbol + ) + + def init_scan_carry(sym: gtir.Sym) -> None: + scan_carry_dataname = str(sym.id) + scan_carry_desc = nsdfg.data(scan_carry_dataname) + input_scan_carry_dataname = _scan_input_name(scan_carry_dataname) + input_scan_carry_desc = scan_carry_desc.clone() + nsdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) + scan_carry_desc.transient = True + init_state.add_nedge( + init_state.add_access(input_scan_carry_dataname), + init_state.add_access(scan_carry_dataname), + nsdfg.make_array_memlet(input_scan_carry_dataname), + ) + + if isinstance(scan_carry_input, tuple): + gtx_utils.tree_map(init_scan_carry)(scan_carry_input) + else: + init_scan_carry(scan_carry_input) + + # Create a loop region over the vertical dimension corresponding to the scan column + if scan_forward: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_upper_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lower_bound}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lower_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + nsdfg.add_node(scan_loop) + nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + + # Inside the loop region, create a 'compute' and an 'update' state. + # The body of the 'compute' state implements the stencil expression for one vertical level. + # The 'update' state writes the value computed by the stencil into the scan carry variable, + # in order to make it available to the next vertical level. + compute_state = scan_loop.add_state("scan_compute") + update_state = scan_loop.add_state_after(compute_state, "scan_update") + + # inside the 'compute' state, visit the list of arguments to be passed to the stencil + stencil_args = [ + _parse_scan_fieldop_arg(im.ref(p.id), nsdfg, compute_state, lambda_translator, domain) + for p in lambda_node.params + ] + # stil inside the 'compute' state, generate the dataflow representing the stencil + # to be applied on the horizontal domain + lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( + nsdfg, compute_state, lambda_translator, lambda_node, args=stencil_args + ) + # connect the dataflow input directly to the source data nodes, without passing through a map node; + # the reason is that the map for horizontal domain is outside the scan loop region + for edge in lambda_input_edges: + edge.connect(map_entry=None) + # connect the dataflow output nodes, called 'scan_result' below, to a global field called 'output' + output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound + + def connect_scan_output( + scan_output_edge: gtir_dataflow.DataflowOutputEdge, + scan_output_shape: list[dace.symbolic.SymExpr], + scan_carry_sym: gtir.Sym, + ) -> gtir_translators.FieldopData: + scan_result = scan_output_edge.result + if isinstance(scan_result.gt_dtype, ts.ScalarType): + assert scan_result.gt_dtype == scan_carry_sym.type + # the scan field operator computes a column of scalar values + assert len(scan_output_shape) == 1 + output_subset = dace_subsets.Range.from_string(str(output_column_index)) + else: + assert isinstance(scan_carry_sym.type, ts.ListType) + assert scan_result.gt_dtype.element_type == scan_carry_sym.type.element_type + # the scan field operator computes a list of scalar values for each column level + assert len(scan_output_shape) == 2 + output_subset = dace_subsets.Range.from_string( + f"{output_column_index}, 0:{scan_output_shape[1]}" + ) + scan_result_data = scan_result.dc_node.data + scan_result_desc = scan_result.dc_node.desc(nsdfg) + + # `sym` represents the global output data, that is the nested-SDFG output connector + scan_carry_data = str(scan_carry_sym.id) + output = _scan_output_name(scan_carry_data) + nsdfg.add_array(output, scan_output_shape, scan_result_desc.dtype) + output_node = compute_state.add_access(output) + + # in the 'compute' state, we write the current vertical level data to the output field + # (the output field is mapped to an external array) + compute_state.add_nedge( + scan_result.dc_node, output_node, dace.Memlet(data=output, subset=output_subset) + ) + + # in the 'update' state, the value of the current vertical level is written + # to the scan carry variable for the next loop iteration + update_state.add_nedge( + update_state.add_access(scan_result_data), + update_state.add_access(scan_carry_data), + dace.Memlet.from_array(scan_result_data, scan_result_desc), + ) + + output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) + return gtir_translators.FieldopData(output_node, output_type, offset=scan_lower_bound) + + # write the stencil result (value on one vertical level) into a 1D field + # with full vertical shape representing one column + if isinstance(scan_carry_input, tuple): + assert isinstance(lambda_result_shape, tuple) + lambda_output = gtx_utils.tree_map(connect_scan_output)( + lambda_result, lambda_result_shape, scan_carry_input + ) + else: + assert isinstance(lambda_result[0], gtir_dataflow.DataflowOutputEdge) + assert isinstance(lambda_result_shape, list) + lambda_output = connect_scan_output(lambda_result[0], lambda_result_shape, scan_carry_input) + + # in case tuples are passed as argument, isolated access nodes might be left in the state, + # because not all tuple fields are necessarily accessed inside the lambda scope + for data_node in compute_state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if compute_state.degree(data_node) == 0: + # By construction there should never be isolated transient nodes. + # Therefore, the assert below implements a sanity check, that allows + # the exceptional case (encountered in one GT4Py test) where the carry + # variable is not used, so not a scan indeed because no data dependency. + assert (not data_desc.transient) or data_node.data.startswith(scan_carry_symbol.id) + compute_state.remove_node(data_node) + + return nsdfg, lambda_output + + +def _connect_nested_sdfg_output_to_temporaries( + sdfg: dace.SDFG, + nsdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + outer_state: dace.SDFGState, + inner_data: gtir_translators.FieldopData, +) -> gtir_dataflow.DataflowOutputEdge: + """ + Helper function to create the edges to write output data from the nested SDFG + to temporary arrays in the parent SDFG, denoted as outer context. + + Args: + sdfg: The SDFG representing the outer context, where the field operator is translated. + nsdfg: The SDFG where the scan `LoopRegion` is translated. + nsdfg_node: The nested SDFG node in the outer context. + outer_state: The state in outer context where the field operator is translated. + inner_data: The data produced by the scan `LoopRegion` in the inner context. + + Returns: + An object representing the output data connection of this field operator. + """ + assert isinstance(inner_data.gt_type, ts.FieldType) + inner_dataname = inner_data.dc_node.data + inner_desc = nsdfg.data(inner_dataname) + outer_dataname, outer_desc = sdfg.add_temp_transient_like(inner_desc) + outer_node = outer_state.add_access(outer_dataname) + outer_state.add_edge( + nsdfg_node, + inner_dataname, + outer_node, + None, + dace.Memlet.from_array(outer_dataname, outer_desc), + ) + output_expr = gtir_dataflow.ValueExpr(outer_node, inner_data.gt_type.dtype) + return gtir_dataflow.DataflowOutputEdge(outer_state, output_expr) + + +def translate_scan( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> gtir_translators.FieldopResult: + """ + Generates the dataflow subgraph for the `as_fieldop` builtin with a scan operator. + + It differs from `translate_as_fieldop()` in that the horizontal domain is lowered + to a map scope, while the scan column computation is lowered to a `LoopRegion` + on the vertical dimension, that is inside the horizontal map. + The current design choice is to keep the map scope on the outer level, and + the `LoopRegion` inside. This choice follows the GTIR representation where + the `scan` operator is called inside the `as_fieldop` node. + + Implements the `PrimitiveTranslator` protocol. + """ + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) + + fun_node = node.fun + assert len(fun_node.args) == 2 + scan_expr, domain_expr = fun_node.args + assert cpm.is_call_to(scan_expr, "scan") + + # parse the domain of the scan field operator + domain = gtir_translators.extract_domain(domain_expr) + + # parse scan parameters + assert len(scan_expr.args) == 3 + stencil_expr = scan_expr.args[0] + assert isinstance(stencil_expr, gtir.Lambda) + + # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension + scan_carry = str(stencil_expr.params[0].id) + + # params[1]: boolean flag for forward/backward scan + assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) + scan_forward = scan_expr.args[1].value == "True" + + # params[2]: the expression that computes the value for scan initialization + init_expr = scan_expr.args[2] + # visit the initialization value of the scan expression + init_data = sdfg_builder.visit(init_expr, sdfg=sdfg, head_state=state) + # extract type definition of the scan carry + scan_carry_type = ( + init_data.gt_type + if isinstance(init_data, gtir_translators.FieldopData) + else gtir_translators.get_tuple_type(init_data) + ) + + # define the set of symbols available in the lambda context, which consists of + # the carry argument and all lambda function arguments + lambda_arg_types = [scan_carry_type] + [ + arg.type for arg in node.args if isinstance(arg.type, ts.DataType) + ] + lambda_symbols = { + str(p.id): arg_type + for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True) + } + + # visit the arguments to be passed to the lambda expression + # this must be executed before visiting the lambda expression, in order to populate + # the data descriptor with the correct field domain offsets for field arguments + lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] + lambda_args_mapping = { + _scan_input_name(scan_carry): init_data, + } | { + str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + } + + # parse the dataflow input and output symbols + lambda_flat_args: dict[str, gtir_translators.FieldopData] = {} + # the field offset is set to `None` when it is zero in all dimensions + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for param, outer_arg in lambda_args_mapping.items(): + tuple_fields = gtir_translators.flatten_tuples(param, outer_arg) + lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} + lambda_flat_args |= dict(tuple_fields) + if isinstance(scan_carry_type, ts.TupleType): + lambda_flat_outs = { + str(sym.id): sym.type + for sym in dace_gtir_utils.flatten_tuple_fields( + _scan_output_name(scan_carry), scan_carry_type + ) + } + else: + lambda_flat_outs = {_scan_output_name(scan_carry): scan_carry_type} + + # lower the scan stencil expression in a separate SDFG context + nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( + stencil_expr, + sdfg, + sdfg_builder, + domain, + init_data, + lambda_symbols, + lambda_field_offsets, + scan_forward, + im.sym(scan_carry, scan_carry_type), + ) + + # build the mapping of symbols from nested SDFG to field operator context + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + for inner_dataname, outer_arg in lambda_flat_args.items(): + inner_desc = nsdfg.data(inner_dataname) + outer_desc = outer_arg.dc_node.desc(sdfg) + nsdfg_symbols_mapping |= { + str(nested_symbol): parent_symbol + for nested_symbol, parent_symbol in zip( + [*inner_desc.shape, *inner_desc.strides], + [*outer_desc.shape, *outer_desc.strides], + strict=True, + ) + if dace.symbolic.issymbolic(nested_symbol) + } + + # the scan nested SDFG is ready: it is instantiated in the field operator context + # where the map scope over the horizontal domain lives + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + inputs=set(lambda_flat_args.keys()), + outputs=set(lambda_flat_outs.keys()), + symbol_mapping=nsdfg_symbols_mapping, + ) + + lambda_input_edges = [] + for input_connector, outer_arg in lambda_flat_args.items(): + arg_desc = outer_arg.dc_node.desc(sdfg) + input_subset = dace_subsets.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector + ) + lambda_input_edges.append(input_edge) + + # for output connections, we create temporary arrays that contain the computation + # results of a column slice for each point in the horizontal domain + lambda_output_tree = gtx_utils.tree_map( + lambda lambda_output_data: _connect_nested_sdfg_output_to_temporaries( + sdfg, nsdfg, nsdfg_node, state, lambda_output_data + ) + )(lambda_output) + + # we call a helper method to create a map scope that will compute the entire field + return _create_scan_field_operator( + sdfg, state, domain, node.type, sdfg_builder, lambda_input_edges, lambda_output_tree + ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 23a36ba79f..1a48959f8c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -22,6 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace +from dace.sdfg import utils as dace_sdfg_utils from gt4py import eve from gt4py.eve import concepts @@ -130,6 +131,38 @@ def get_symbol_type(self, symbol_name: str) -> ts.DataType: """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... + @abc.abstractmethod + def is_column_axis(self, dim: gtx_common.Dimension) -> bool: + """Check if the given dimension is the column axis.""" + ... + + @abc.abstractmethod + def setup_nested_context( + self, + expr: gtir.Expr, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], + ) -> SDFGBuilder: + """ + Create an SDFG context to translate a nested expression, indipendent + from the current context where the parent expression is being translated. + + This method will setup the global symbols, that correspond to the parameters + of the expression to be lowered, as well as the set of symbolic arguments, + that is scalar values used in internal domain expressions. + + Args: + expr: The nested expresson to be lowered. + sdfg: The SDFG where to lower the nested expression. + global_symbols: Mapping from symbol name to GTIR data type. + field_offsets: Mapping from symbol name to field origin, `None` if field origin is 0 in all dimensions. + + Returns: + A visitor object implementing the `SDFGBuilder` protocol. + """ + ... + @abc.abstractmethod def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: """Visit a node of the GT4Py IR.""" @@ -183,6 +216,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """ offset_provider_type: gtx_common.OffsetProviderType + column_axis: Optional[gtx_common.Dimension] global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( default_factory=dict @@ -209,6 +243,29 @@ def make_field( def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] + def is_column_axis(self, dim: gtx_common.Dimension) -> bool: + assert self.column_axis + return dim == self.column_axis + + def setup_nested_context( + self, + expr: gtir.Expr, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], + ) -> SDFGBuilder: + nsdfg_builder = GTIRToSDFG( + self.offset_provider_type, self.column_axis, global_symbols, field_offsets + ) + nsdfg_params = [ + gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() + ] + domain_symbols = _collect_symbols_in_domain_expressions(expr, nsdfg_params) + nsdfg_builder._add_sdfg_params( + sdfg, node_params=nsdfg_params, symbolic_arguments=domain_symbols + ) + return nsdfg_builder + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: nsdfg_list = [ nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) @@ -296,9 +353,9 @@ def _add_storage( if len(gt_type.dims) == 0: # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) + if not isinstance(gt_type.dtype, ts.ScalarType): + raise ValueError(f"Field type '{gt_type.dtype}' not supported.") # handle default case: field with one or more dimensions - # ListType not supported: concept is represented as Field with local Dimension - assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) # Use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. @@ -391,6 +448,7 @@ def _add_sdfg_params( except when they are listed in 'symbolic_arguments', in which case they will be represented in the SDFG as DaCe symbols. """ + # add non-transient arrays and/or SDFG symbols for the program arguments sdfg_args = [] for param in node_params: @@ -674,19 +732,10 @@ def get_field_domain_offset( lambda_field_offsets |= get_field_domain_offset(p_name, p_type) # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG( - self.offset_provider_type, lambda_symbols, lambda_field_offsets - ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) - - # add sdfg storage for the symbols that need to be passed as input parameters - lambda_params = [ - gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items() - ] - lambda_domain_symbols = _collect_symbols_in_domain_expressions(node.expr, lambda_params) - lambda_translator._add_sdfg_params( - nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols + lambda_translator = self.setup_nested_context( + node.expr, nsdfg, lambda_symbols, lambda_field_offsets ) nstate = nsdfg.add_state("lambda") @@ -723,7 +772,7 @@ def get_field_domain_offset( [*datadesc.shape, *datadesc.strides], strict=True, ) - if isinstance(nested_symbol, dace.symbol) + if dace.symbolic.issymbolic(nested_symbol) } else: dataname = nsdfg_dataname @@ -851,6 +900,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, + column_axis: Optional[gtx_common.Dimension] = None, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -861,6 +911,7 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG offset_provider_type: The definitions of offset providers used by the program node + column_axis: Vertical dimension used for column scan expressions. Returns: An SDFG in the DaCe canonical form (simplified) @@ -875,8 +926,11 @@ def build_sdfg_from_gtir( # Here we find new names for invalid symbols present in the IR. ir = dace_gtir_utils.replace_invalid_symbols(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider_type) + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) + # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct + dace_sdfg_utils.inline_loop_blocks(sdfg) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 779dc8a1c9..a83654ebc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -55,7 +55,7 @@ def generate_sdfg( if not self.itir_transforms_off: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( - ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ir, common.offset_provider_to_type(offset_provider), column_axis ) if auto_opt: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index e19d9e1d81..6ffbc667bb 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -104,6 +104,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" +USES_SCAN_1D_FIELD = "uses_scan_1d_field" USES_SPARSE_FIELDS = "uses_sparse_fields" USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" @@ -146,7 +147,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), @@ -184,9 +184,17 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST + + [ + # dace issue https://github.com/spcl/dace/issues/1773 + (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 644e0c6103..e301dbe11b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -819,6 +819,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan +@pytest.mark.uses_scan_1d_field def test_ternary_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: From 8eae147f94346d14a8338fa33c04bc1938c167ee Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 21 Jan 2025 10:39:21 +0100 Subject: [PATCH 101/178] style[cartesian]: Remove unused optional keyword arguments (#1805) --- src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 5f2007871e..952bafd46a 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -417,7 +417,6 @@ def visit_HorizontalExecution( global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - loop_order, k_interval, **kwargs: Any, ): @@ -522,7 +521,6 @@ def visit_VerticalLoopSection( self, node: oir.VerticalLoopSection, *, - loop_order, iteration_ctx: DaCeIRBuilder.IterationContext, global_ctx: DaCeIRBuilder.GlobalContext, symbol_collector: DaCeIRBuilder.SymbolCollector, @@ -546,7 +544,6 @@ def visit_VerticalLoopSection( iteration_ctx=iteration_ctx, global_ctx=global_ctx, symbol_collector=symbol_collector, - loop_order=loop_order, k_interval=node.interval, **kwargs, ) @@ -723,7 +720,6 @@ def _process_loop_item( scope_nodes, item: Loop, *, - global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, **kwargs: Any, @@ -840,7 +836,6 @@ def visit_VerticalLoop( sections = flatten_list( self.generic_visit( node.sections, - loop_order=node.loop_order, global_ctx=global_ctx, iteration_ctx=iteration_ctx, symbol_collector=symbol_collector, From 7e566fcd8e3317307fa3175a13f44692ced7b691 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 21 Jan 2025 10:39:51 +0100 Subject: [PATCH 102/178] feature[next]: Runtime check args in is_call_to (#1796) Calling `is_call_to` with the arguments in the wrong order happens easily. This PR adds a runtime check to avoid this. --- .../next/iterator/ir_utils/common_pattern_matcher.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 9df091ac2a..19d0802f4b 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import TypeGuard +from typing import Any, TypeGuard from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -63,10 +63,14 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: +def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: """ Match call expression to a given function. + If the `node` argument is not an `itir.Node` the function does not error, but just returns + `False`. This is useful in visitors, where sometimes we pass a list of nodes or a leaf + attribute which can be anything. + >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> node = im.call("plus")(1, 2) >>> is_call_to(node, "plus") @@ -76,6 +80,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC >>> is_call_to(node, ("plus", "minus")) True """ + assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): return any((is_call_to(node, f) for f in fun)) return ( From ae603cb167c1c90633d29db181eb964ebf4db41d Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 21 Jan 2025 12:18:56 +0100 Subject: [PATCH 103/178] refactor[next][dace]: normalize SDFG field type with local dimension (#1808) Fields with a local dimension can be passed as program arguments. The corresponding `FieldType` parameter type in GTIR contains the local dimension in the list of field domain dimensions, while the data type of the field elements is `ScalarType`. For example: `ts.FieldType(dims=[Vertex, V2EDim], dtype=FLOAT_TYPE)` where: ``` V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) FLOAT_TYPE = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) ``` Except for the program arguments, the internal representation in the SDFG lowering should only contain global dimensions in the field domain, and use `ListType` for the element type in case of a list of values: `ts.FieldType(dims=[Vertex], dtype=ts.ListType(element_type=FLOAT_TYPE, offset_type=V2EDim))` With this PR, the normalized form is used across the SDFG lowering. The `make_field` helper method is modified to convert the type definition of field arguments to the normalized form. --- .../gtir_builtin_translators.py | 47 +++++------------- .../dace_fieldview/gtir_scan_translator.py | 9 ++-- .../runners/dace_fieldview/gtir_sdfg.py | 49 +++++++++++++++++-- 3 files changed, 60 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 1b7e1e15c1..b0d09e0a15 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -111,31 +111,13 @@ def get_local_view( (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) for i, dim in enumerate(self.gt_type.dims) ] - local_dims = [ - dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL - ] - if len(local_dims) == 0: - return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, field_domain, it_indices - ) - - elif len(local_dims) == 1: - field_dtype = ts.ListType( - element_type=self.gt_type.dtype, offset_type=local_dims[0] - ) - field_domain = [ - (dim, offset) - for dim, offset in field_domain - if dim.kind != gtx_common.DimensionKind.LOCAL - ] - return gtir_dataflow.IteratorExpr( - self.dc_node, field_dtype, field_domain, it_indices - ) - - else: - raise ValueError( - f"Unexpected data field {self.dc_node.data} with more than one local dimension." - ) + # The property below is ensured by calling `make_field()` to construct `FieldopData`. + # The `make_field` constructor ensures that any local dimension, if present, is converted + # to `ListType` element type, while the field domain consists of all global dimensions. + assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) + return gtir_dataflow.IteratorExpr( + self.dc_node, self.gt_type.dtype, field_domain, it_indices + ) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") @@ -305,29 +287,24 @@ def _create_field_operator_impl( raise TypeError( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) - field_dtype = output_edge.result.gt_dtype - field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) assert isinstance(dataflow_output_desc, dace.data.Scalar) + field_shape = domain_shape field_subset = domain_subset else: assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = output_edge.result.gt_dtype.element_type - if field_dtype != output_type.dtype.element_type: + if output_edge.result.gt_dtype.element_type != output_type.dtype.element_type: raise TypeError( - f"Type mismatch, expected {output_type.dtype.element_type} got {field_dtype}." + f"Type mismatch, expected {output_type.dtype.element_type} got {output_edge.result.gt_dtype.element_type}." ) assert isinstance(dataflow_output_desc, dace.data.Array) assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) assert output_edge.result.gt_dtype.offset_type is not None - field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] field_shape = [*domain_shape, dataflow_output_desc.shape[0]] - field_offset = [*domain_offset, dataflow_output_desc.offset[0]] field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) field_node = state.add_access(field_name) @@ -336,8 +313,8 @@ def _create_field_operator_impl( return FieldopData( field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), + ts.FieldType(domain_dims, output_edge.result.gt_dtype), + offset=(domain_offset if set(domain_offset) != {0} else None), ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py index e105030908..27551a68bf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py @@ -131,7 +131,7 @@ def _create_scan_field_operator_impl( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) field_dtype = output_edge.result.gt_dtype - field_dims, field_shape, field_offset = (domain_dims, domain_shape, domain_offset) + field_shape = domain_shape # the scan field operator computes a column of scalar values assert len(dataflow_output_desc.shape) == 1 else: @@ -147,15 +147,12 @@ def _create_scan_field_operator_impl( assert len(dataflow_output_desc.shape) == 2 # the lines below extend the array with the local dimension added by the field operator assert output_edge.result.gt_dtype.offset_type is not None - field_dims = [*domain_dims, output_edge.result.gt_dtype.offset_type] field_shape = [*domain_shape, dataflow_output_desc.shape[1]] - field_offset = [*domain_offset, dataflow_output_desc.offset[1]] field_subset = field_subset + dace_subsets.Range.from_string( f"0:{dataflow_output_desc.shape[1]}" ) # allocate local temporary storage - assert dataflow_output_desc.dtype == dace_utils.as_dace_type(field_dtype) field_name, field_desc = sdfg_builder.add_temp_array( sdfg, field_shape, dataflow_output_desc.dtype ) @@ -173,8 +170,8 @@ def _create_scan_field_operator_impl( return gtir_translators.FieldopData( field_node, - ts.FieldType(field_dims, field_dtype), - offset=(field_offset if set(field_offset) != {0} else None), + ts.FieldType(domain_dims, output_edge.result.gt_dtype), + offset=(domain_offset if set(domain_offset) != {0} else None), ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 1a48959f8c..2139ffe578 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -234,11 +234,52 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType def make_field( self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType ) -> gtir_builtin_translators.FieldopData: - if isinstance(data_type, ts.FieldType): - domain_offset = self.field_offsets.get(data_node.data, None) + """ + Helper method to build the field data type associated with an access node in the SDFG. + + In case of `ScalarType` data, the descriptor is constructed with `offset=None`. + In case of `FieldType` data, the field origin is added to the data descriptor. + Besides, if the `FieldType` contains a local dimension, the descriptor is converted + to a canonical form where the field domain consists of all global dimensions + (the grid axes) and the field data type is `ListType`, with `offset_type` equal + to the field local dimension. + + Args: + data_node: The access node to the SDFG data storage. + data_type: The GT4Py data descriptor, which can either come from a field parameter + of an expression node, or from an intermediate field in a previous expression. + + Returns: + The descriptor associated with the SDFG data storage, filled with field origin. + """ + if isinstance(data_type, ts.ScalarType): + return gtir_builtin_translators.FieldopData(data_node, data_type, offset=None) + domain_offset = self.field_offsets.get(data_node.data, None) + local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] + if len(local_dims) == 0: + # do nothing: the field domain consists of all global dimensions + field_type = data_type + elif len(local_dims) == 1: + local_dim = local_dims[0] + local_dim_index = data_type.dims.index(local_dim) + # the local dimension is converted into `ListType` data element + if not isinstance(data_type.dtype, ts.ScalarType): + raise ValueError(f"Invalid field type {data_type}.") + if local_dim_index != (len(data_type.dims) - 1): + raise ValueError( + f"Invalid field domain: expected the local dimension to be at the end, found at position {local_dim_index}." + ) + if local_dim.value not in self.offset_provider_type: + raise ValueError( + f"The provided local dimension {local_dim} does not match any offset provider type." + ) + local_type = ts.ListType(element_type=data_type.dtype, offset_type=local_dim) + field_type = ts.FieldType(dims=data_type.dims[:local_dim_index], dtype=local_type) else: - domain_offset = None - return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + raise NotImplementedError( + "Fields with more than one local dimension are not supported." + ) + return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset) def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] From 022a73c86a282d1407854173518accf1c50d9cae Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 21 Jan 2025 13:36:57 +0100 Subject: [PATCH 104/178] feat[next]: Add support for more datatypes (#1786) This builds on [PR#1708](https://github.com/GridTools/gt4py/pull/1708) without the `float16` and `bfloat16` changes. Add support for `int8, uin8, int16, uint16, uint32` and `uint64`. Move builtin definitions from `src/gt4py/next/iterator/ir.py` to `src/gt4py/next/iterator/builtins.py`. Use ascending integer values in `ScalarKind`-Enum and modify tests respectively. Set `start: int = 1` in `tests/next_tests/integration_tests/cases.py` to not start initialization from zero as this has the same value as zero-initialized memory and modify tests respectively. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/ffront/fbuiltins.py | 11 +- .../ffront/foast_passes/type_deduction.py | 8 +- src/gt4py/next/ffront/past_to_itir.py | 4 +- src/gt4py/next/ffront/type_info.py | 2 +- src/gt4py/next/iterator/builtins.py | 122 ++++++++++++------ src/gt4py/next/iterator/embedded.py | 122 ++++++------------ src/gt4py/next/iterator/ir.py | 82 +----------- .../next/iterator/ir_utils/domain_utils.py | 6 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 20 +-- .../iterator/transforms/constant_folding.py | 4 +- .../next/iterator/transforms/infer_domain.py | 6 +- .../next/iterator/transforms/prune_casts.py | 4 +- .../next/iterator/transforms/trace_shifts.py | 6 +- .../next/iterator/type_system/inference.py | 14 +- .../iterator/type_system/type_synthesizer.py | 12 +- src/gt4py/next/otf/binding/cpp_interface.py | 24 +--- src/gt4py/next/otf/binding/nanobind.py | 6 +- src/gt4py/next/otf/cpp_utils.py | 32 +++++ .../codegens/gtfn/codegen.py | 14 +- .../codegens/gtfn/gtfn_ir.py | 6 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 19 +-- .../runners/dace_common/utility.py | 19 ++- .../runners/dace_fieldview/gtir_dataflow.py | 4 +- .../dace_fieldview/gtir_python_codegen.py | 4 +- src/gt4py/next/type_system/type_info.py | 13 +- .../next/type_system/type_specifications.py | 16 ++- .../next/type_system/type_translation.py | 10 +- tests/next_tests/integration_tests/cases.py | 4 +- .../ffront_tests/test_execution.py | 6 +- .../iterator_tests/test_program.py | 5 +- .../feature_tests/test_util_cases.py | 6 +- tests/next_tests/toy_connectivity.py | 10 +- .../ffront_tests/test_past_to_gtir.py | 6 +- .../iterator_tests/test_pretty_parser.py | 4 +- .../iterator_tests/test_pretty_printer.py | 4 +- .../transforms_tests/test_global_tmps.py | 4 +- .../binding_tests/test_cpp_interface.py | 16 +-- .../gtfn_tests/test_gtfn_module.py | 6 +- 38 files changed, 304 insertions(+), 357 deletions(-) create mode 100644 src/gt4py/next/otf/cpp_utils.py diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 1210e96efc..cef7fc101f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np -from numpy import float32, float64, int32, int64 +from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64 from gt4py._core import definitions as core_defs from gt4py.next import common @@ -29,12 +29,19 @@ TYPE_BUILTINS = [ common.Field, common.Dimension, + int8, + uint8, + int16, + uint16, int32, + uint32, int64, + uint64, float32, float64, *PYTHON_TYPE_BUILTINS, -] +] # TODO(tehrengruber): validate matches iterator.builtins.TYPE_BUILTINS? + TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6b40cbb77f..26bcadaef1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -79,7 +79,7 @@ def construct_tuple_type( ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) - [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] + [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ element_types_new = true_branch_types for i, element in enumerate(true_branch_types): @@ -111,15 +111,15 @@ def promote_to_mask_type( >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) ... ) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype) ... ) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) """ if isinstance(input_type, ts.ScalarType) or not all( item in input_type.dims for item in mask_type.dims diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 5adc229595..4bc1dfb2f8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -24,7 +24,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.stages import AOT_PRG -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -218,7 +218,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` assert all(field_dims == fields_dims[0] for field_dims in fields_dims) index_type = ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ) for dim_idx in range(len(fields_dims[0])): size_params.append( diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 83ecf92839..80ba93e187 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -173,7 +173,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType ... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) ... ), ... ) - FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index c8edc12331..959f451e01 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -337,16 +337,46 @@ def int(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def int8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def int16(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint16(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int32(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint32(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int64(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint64(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def float(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() @@ -368,6 +398,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] UNARY_MATH_NUMBER_BUILTINS = {"abs"} +UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { "sin", "cos", @@ -391,52 +422,69 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "trunc", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"} -TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"} -MATH_BUILTINS = ( - UNARY_MATH_NUMBER_BUILTINS - | UNARY_MATH_FP_BUILTINS - | UNARY_MATH_FP_PREDICATE_BUILTINS - | BINARY_MATH_NUMBER_BUILTINS - | TYPEBUILTINS -) +BINARY_MATH_NUMBER_BUILTINS = { + "plus", + "minus", + "multiplies", + "divides", + "mod", + "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 + "minimum", + "maximum", + "fmod", +} +BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} +BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} + + +#: builtin / dtype used to construct integer indices, like domain bounds +INTEGER_INDEX_BUILTIN = "int32" +INTEGER_TYPE_BUILTINS = { + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", +} +FLOATING_POINT_TYPE_BUILTINS = {"float32", "float64"} +TYPE_BUILTINS = {*INTEGER_TYPE_BUILTINS, *FLOATING_POINT_TYPE_BUILTINS, "bool"} + +ARITHMETIC_BUILTINS = { + *UNARY_MATH_NUMBER_BUILTINS, + *UNARY_LOGICAL_BUILTINS, + *UNARY_MATH_FP_BUILTINS, + *UNARY_MATH_FP_PREDICATE_BUILTINS, + *BINARY_MATH_NUMBER_BUILTINS, + "power", + *BINARY_MATH_COMPARISON_BUILTINS, + *BINARY_LOGICAL_BUILTINS, +} + BUILTINS = { - "deref", + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "can_deref", + "cartesian_domain", + "cast_", + "deref", + "if_", + "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", - "neighbors", "list_get", + "lift", "make_const_list", + "make_tuple", "map_", - "lift", + "named_range", + "neighbors", "reduce", - "plus", - "minus", - "multiplies", - "divides", - "floordiv", - "mod", - "make_tuple", - "tuple_get", - "if_", - "cast_", - "greater", - "less", - "less_equal", - "greater_equal", - "eq", - "not_eq", - "not_", - "and_", - "or_", - "xor_", "scan", - "cartesian_domain", + "tuple_get", "unstructured_domain", - "named_range", - "as_fieldop", - "index", - *MATH_BUILTINS, + *ARITHMETIC_BUILTINS, + *TYPE_BUILTINS, } __all__ = [*BUILTINS] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 5949d29432..970e88e8c5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -15,6 +15,7 @@ import dataclasses import itertools import math +import operator import sys import warnings @@ -85,7 +86,19 @@ TupleAxis: TypeAlias = type[None] Axis: TypeAlias = Union[FieldAxis, TupleAxis] Scalar: TypeAlias = ( - SupportsInt | SupportsFloat | np.int32 | np.int64 | np.float32 | np.float64 | np.bool_ + SupportsInt + | SupportsFloat + | np.int8 + | np.uint8 + | np.int16 + | np.uint16 + | np.int32 + | np.uint32 + | np.int64 + | np.uint64 + | np.float32 + | np.float64 + | np.bool_ ) @@ -389,27 +402,6 @@ def gamma(a): return res.item() -@builtins.and_.register(EMBEDDED) -def and_(a, b): - if isinstance(a, Column): - return np.logical_and(a, b) - return a and b - - -@builtins.or_.register(EMBEDDED) -def or_(a, b): - if isinstance(a, Column): - return np.logical_or(a, b) - return a or b - - -@builtins.xor_.register(EMBEDDED) -def xor_(a, b): - if isinstance(a, Column): - return np.logical_xor(a, b) - return a ^ b - - @builtins.tuple_get.register(EMBEDDED) def tuple_get(i, tup): if isinstance(tup, Column): @@ -497,66 +489,6 @@ def named_range(tag: Tag | common.Dimension, start: int, end: int) -> NamedRange return (tag, range(start, end)) -@builtins.minus.register(EMBEDDED) -def minus(first, second): - return first - second - - -@builtins.plus.register(EMBEDDED) -def plus(first, second): - return first + second - - -@builtins.multiplies.register(EMBEDDED) -def multiplies(first, second): - return first * second - - -@builtins.divides.register(EMBEDDED) -def divides(first, second): - return first / second - - -@builtins.floordiv.register(EMBEDDED) -def floordiv(first, second): - return first // second - - -@builtins.mod.register(EMBEDDED) -def mod(first, second): - return first % second - - -@builtins.eq.register(EMBEDDED) -def eq(first, second): - return first == second - - -@builtins.greater.register(EMBEDDED) -def greater(first, second): - return first > second - - -@builtins.less.register(EMBEDDED) -def less(first, second): - return first < second - - -@builtins.less_equal.register(EMBEDDED) -def less_equal(first, second): - return first <= second - - -@builtins.greater_equal.register(EMBEDDED) -def greater_equal(first, second): - return first >= second - - -@builtins.not_eq.register(EMBEDDED) -def not_eq(first, second): - return first != second - - CompositeOfScalarOrField: TypeAlias = Scalar | common.Field | tuple["CompositeOfScalarOrField", ...] @@ -585,11 +517,31 @@ def promote_scalars(val: CompositeOfScalarOrField): ) -for math_builtin_name in builtins.MATH_BUILTINS: - python_builtins = {"int": int, "float": float, "bool": bool, "str": str} +for math_builtin_name in builtins.ARITHMETIC_BUILTINS | builtins.TYPE_BUILTINS: + python_builtins: dict[str, Callable] = { + "int": int, + "float": float, + "bool": bool, + "str": str, + "plus": operator.add, + "minus": operator.sub, + "multiplies": operator.mul, + "divides": operator.truediv, + "mod": operator.mod, + "floordiv": operator.floordiv, + "eq": operator.eq, + "less": operator.lt, + "greater": operator.gt, + "greater_equal": operator.ge, + "less_equal": operator.le, + "not_eq": operator.ne, + "and_": operator.and_, + "or_": operator.or_, + "xor_": operator.xor, + } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name == "gamma": + if math_builtin_name in ["gamma", "not_"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e875709631..ea5cf84d86 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -14,6 +14,7 @@ from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable from gt4py.next import common +from gt4py.next.iterator.builtins import BUILTINS from gt4py.next.type_system import type_specifications as ts @@ -93,87 +94,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -UNARY_MATH_NUMBER_BUILTINS = {"abs"} -UNARY_LOGICAL_BUILTINS = {"not_"} -UNARY_MATH_FP_BUILTINS = { - "sin", - "cos", - "tan", - "arcsin", - "arccos", - "arctan", - "sinh", - "cosh", - "tanh", - "arcsinh", - "arccosh", - "arctanh", - "sqrt", - "exp", - "log", - "gamma", - "cbrt", - "floor", - "ceil", - "trunc", -} -UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = { - "minimum", - "maximum", - "fmod", - "plus", - "minus", - "multiplies", - "divides", - "mod", - "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 -} -BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} -BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} - -ARITHMETIC_BUILTINS = { - *UNARY_MATH_NUMBER_BUILTINS, - *UNARY_LOGICAL_BUILTINS, - *UNARY_MATH_FP_BUILTINS, - *UNARY_MATH_FP_PREDICATE_BUILTINS, - *BINARY_MATH_NUMBER_BUILTINS, - "power", - *BINARY_MATH_COMPARISON_BUILTINS, - *BINARY_LOGICAL_BUILTINS, -} - -#: builtin / dtype used to construct integer indices, like domain bounds -INTEGER_INDEX_BUILTIN = "int32" -INTEGER_BUILTINS = {"int32", "int64"} -FLOATING_POINT_BUILTINS = {"float32", "float64"} -TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} - -BUILTINS = { - "tuple_get", - "cast_", - "cartesian_domain", - "unstructured_domain", - "make_tuple", - "shift", - "neighbors", - "named_range", - "list_get", - "map_", - "make_const_list", - "lift", - "reduce", - "deref", - "can_deref", - "scan", - "if_", - "index", # `index(dim)` creates a dim-field that has the current index at each point - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) - *ARITHMETIC_BUILTINS, - *TYPEBUILTINS, -} - - class Stmt(Node): ... diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4a023f7535..c84e2c0228 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -13,7 +13,7 @@ from typing import Any, Literal, Mapping, Optional from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -127,7 +127,7 @@ def translate( else: # note: ugly but cheap re-computation, but should disappear horizontal_sizes = { - k: im.literal(str(v), itir.INTEGER_INDEX_BUILTIN) + k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } @@ -137,7 +137,7 @@ def translate( assert new_dim not in new_ranges or old_dim == new_dim new_range = SymbolicRange( - im.literal("0", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), horizontal_sizes[new_dim.value], ) new_ranges = dict( diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..c5cf2efa5a 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -11,7 +11,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -29,7 +29,7 @@ def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = Non >>> a = sym("a", "float32") >>> a.id, a.type - (SymbolName('a'), ScalarType(kind=, shape=None)) + (SymbolName('a'), ScalarType(kind=, shape=None)) """ if isinstance(sym_or_name, itir.Sym): assert not type_ @@ -53,7 +53,7 @@ def ref( >>> a = ref("a", "float32") >>> a.id, a.type - (SymbolRef('a'), ScalarType(kind=, shape=None)) + (SymbolRef('a'), ScalarType(kind=, shape=None)) """ if isinstance(ref_or_name, itir.SymRef): assert not type_ @@ -71,7 +71,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type=ScalarType(kind=, shape=None)) + Literal(value='3', type=ScalarType(kind=, shape=None)) >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') @@ -134,7 +134,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) """ def __init__(self, expr): @@ -238,7 +238,7 @@ def make_tuple(*args): def tuple_get(index: str | int, tuple_expr): """Create a tuple_get FunCall, shorthand for ``call("tuple_get")(index, tuple_expr)``.""" - return call("tuple_get")(literal(str(index), itir.INTEGER_INDEX_BUILTIN), tuple_expr) + return call("tuple_get")(literal(str(index), builtins.INTEGER_INDEX_BUILTIN), tuple_expr) def if_(cond, true_val, false_val): @@ -316,11 +316,11 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. >>> literal_from_value(1.0) - Literal(value='1.0', type=ScalarType(kind=, shape=None)) + Literal(value='1.0', type=ScalarType(kind=, shape=None)) >>> literal_from_value(1) - Literal(value='1', type=ScalarType(kind=, shape=None)) + Literal(value='1', type=ScalarType(kind=, shape=None)) >>> literal_from_value(2147483648) - Literal(value='2147483648', type=ScalarType(kind=, shape=None)) + Literal(value='2147483648', type=ScalarType(kind=, shape=None)) >>> literal_from_value(True) Literal(value='True', type=ScalarType(kind=, shape=None)) """ @@ -335,7 +335,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: assert isinstance(type_spec, ts.ScalarType) typename = type_spec.kind.name.lower() - assert typename in itir.TYPEBUILTINS + assert typename in builtins.TYPE_BUILTINS return literal(str(val), typename) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..7215d0787a 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator import builtins, embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -44,7 +44,7 @@ def visit_FunCall(self, node: ir.FunCall): and all(isinstance(arg, ir.Literal) for arg in new_node.args) ): # `1 + 1` -> `2` try: - if new_node.fun.id in ir.ARITHMETIC_BUILTINS: + if new_node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(new_node.fun.id)) arg_values = [ getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f26d3f9ec2..f3c3185225 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -15,7 +15,7 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, @@ -383,8 +383,8 @@ def _infer_expr( elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) elif ( - cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) - or cpm.is_call_to(expr, itir.TYPEBUILTINS) + cpm.is_call_to(expr, builtins.ARITHMETIC_BUILTINS) + or cpm.is_call_to(expr, builtins.TYPE_BUILTINS) or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py index c825f68a5f..3276f47042 100644 --- a/src/gt4py/next/iterator/transforms/prune_casts.py +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import builtins, ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.type_system import type_specifications as ts @@ -31,7 +31,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: assert ( value.type and isinstance(type_constructor, ir.SymRef) - and (type_constructor.id in ir.TYPEBUILTINS) + and (type_constructor.id in builtins.TYPE_BUILTINS) ) dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 68346b6622..4c44d660f6 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,7 +13,7 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir +from gt4py.next.iterator import builtins, ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -278,9 +278,9 @@ def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] - elif node.id in ir.TYPEBUILTINS: + elif node.id in builtins.TYPE_BUILTINS: return Sentinel.TYPE - elif node.id in (ir.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): + elif node.id in (builtins.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): return _combine raise ValueError(f"Undefined symbol {node.id}") diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index d0d39cbd34..901cb103da 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -17,7 +17,7 @@ from gt4py.eve import concepts from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to from gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts @@ -147,7 +147,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> power(float_type, int_type) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) Now, consider a simple lambda function that squares its argument using the power builtin. A type synthesizer for this function is simple to formulate, but merely gives us the return @@ -159,7 +159,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... type_synthesizer=lambda base: power(base, int_type) ... ) >>> square_func_type_synthesizer(float_type, offset_provider_type={}) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such the type inference algorithm has to defer typing until then. This task is handled transparently @@ -173,7 +173,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... store_inferred_type_in_node=True, ... ) >>> o_type_synthesizer(float_type, offset_provider_type={}) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type ... ) @@ -566,7 +566,9 @@ def visit_OffsetLiteral( if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( - value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) + value=ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) ) else: assert isinstance(self.dimensions, dict) @@ -616,7 +618,7 @@ def visit_FunCall( self.visit(value, ctx=ctx) # ensure types in value are also inferred assert ( isinstance(type_constructor, itir.SymRef) - and type_constructor.id in itir.TYPEBUILTINS + and type_constructor.id in builtins.TYPE_BUILTINS ) return ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6e9936c4af..f5aeac7943 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -14,7 +14,7 @@ from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -81,7 +81,7 @@ def _register_builtin_type_synthesizer( @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS + fun_names=builtins.UNARY_MATH_NUMBER_BUILTINS | builtins.UNARY_MATH_FP_BUILTINS ) def _(val: ts.ScalarType) -> ts.ScalarType: return val @@ -92,7 +92,7 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: return base -@_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: if isinstance(lhs, ts.DeferredType): return rhs @@ -103,14 +103,14 @@ def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_FP_PREDICATE_BUILTINS | itir.UNARY_LOGICAL_BUILTINS + fun_names=builtins.UNARY_MATH_FP_PREDICATE_BUILTINS | builtins.UNARY_LOGICAL_BUILTINS ) def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @_register_builtin_type_synthesizer( - fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS + fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS ) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -197,7 +197,7 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType: def index(arg: ts.DimensionType) -> ts.FieldType: return ts.FieldType( dims=[arg.dim], - dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), + dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())), ) diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index d112a9c256..17eee4d5c6 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -8,7 +8,7 @@ from typing import Final, Sequence -from gt4py.next.otf import languages +from gt4py.next.otf import cpp_utils, languages from gt4py.next.otf.binding import interface from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -18,32 +18,12 @@ ) -def render_scalar_type(scalar_type: ts.ScalarType) -> str: - match scalar_type.kind: - case ts.ScalarKind.BOOL: - return "bool" - case ts.ScalarKind.INT32: - return "std::int32_t" - case ts.ScalarKind.INT64: - return "std::int64_t" - case ts.ScalarKind.FLOAT32: - return "float" - case ts.ScalarKind.FLOAT64: - return "double" - case ts.ScalarKind.STRING: - return "std::string" - case _: - raise AssertionError( - f"Scalar kind '{scalar_type}' is not implemented when it should be." - ) - - def render_function_declaration(function: interface.Function, body: str) -> str: template_params: list[str] = [] rendered_params: list[str] = [] for index, param in enumerate(function.parameters): if isinstance(param.type_, ts.ScalarType): - rendered_params.append(f"{render_scalar_type(param.type_)} {param.name}") + rendered_params.append(f"{cpp_utils.pytype_to_cpptype(param.type_)} {param.name}") elif ti.is_type_or_tuple_of_type(param.type_, (ts.FieldType, ts.ScalarType)): template_param = f"ArgT{index}" template_params.append(f"class {template_param}") diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 3abf49788f..a2cf480d7f 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -14,7 +14,7 @@ import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator -from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf import cpp_utils, languages, stages, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.type_system import type_specifications as ts @@ -88,13 +88,13 @@ def _type_string(type_: ts.TypeSpec) -> str: ndims = len(type_.dims) # cannot be ListType: the concept is represented as Field with local Dimension in this interface assert isinstance(type_.dtype, ts.ScalarType) - dtype = cpp_interface.render_scalar_type(type_.dtype) + dtype = cpp_utils.pytype_to_cpptype(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>" return f"std::pair<{buffer_t}, {origin_t}>" elif isinstance(type_, ts.ScalarType): - return cpp_interface.render_scalar_type(type_) + return cpp_utils.pytype_to_cpptype(type_) else: raise ValueError(f"Type '{type_}' is not supported in nanobind interfaces.") diff --git a/src/gt4py/next/otf/cpp_utils.py b/src/gt4py/next/otf/cpp_utils.py new file mode 100644 index 0000000000..8b2af40eb5 --- /dev/null +++ b/src/gt4py/next/otf/cpp_utils.py @@ -0,0 +1,32 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.type_system import type_specifications as ts + + +def pytype_to_cpptype(t: ts.ScalarType | str) -> str: + if isinstance(t, ts.ScalarType): + t = t.kind.name.lower() + try: + return { + "float32": "float", + "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", + "int32": "std::int32_t", + "uint32": "std::uint32_t", + "int64": "std::int64_t", + "uint64": "std::uint64_t", + "bool": "bool", + "string": "string", + }[t] + except KeyError: + raise TypeError(f"Unsupported type '{t}'.") from None diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index bfc45d7944..c6bf28d8e0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -11,8 +11,8 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import common +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn import gtfn_im_ir, gtfn_ir, gtfn_ir_common -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import pytype_to_cpptype class GTFNCodegen(codegen.TemplatedGenerator): @@ -52,8 +52,14 @@ class GTFNCodegen(codegen.TemplatedGenerator): "power": "std::pow", "float32": "float", "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", "int32": "std::int32_t", + "uint32": "std::uint32_t", "int64": "std::int64_t", + "uint64": "std::uint64_t", "bool": "bool", "plus": "std::plus{}", "minus": "std::minus{}", @@ -92,8 +98,11 @@ def asfloat(value: str) -> str: return value def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: + if node.type == "axis_literal": + return node.value + # TODO(tehrengruber): isn't this wrong and int32 should be casted to an actual int32? - match pytype_to_cpptype(node.type): + match cpp_utils.pytype_to_cpptype(node.type): case "float": return self.asfloat(node.value) + "f" case "double": @@ -101,6 +110,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: case "bool": return node.value.lower() case _: + # TODO(tehrengruber): we should probably shouldn't just allow anything here. Revisit. return node.value IntegralConstant = as_fmt("{value}_c") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 85a100a88d..831694791a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -13,7 +13,7 @@ from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef @@ -230,8 +230,8 @@ class TemporaryAllocation(Node): "reduce", "index", ] -ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS -TYPEBUILTINS = itir.TYPEBUILTINS +ARITHMETIC_BUILTINS = builtins.ARITHMETIC_BUILTINS +TYPEBUILTINS = builtins.TYPE_BUILTINS BUILTINS = {*GTFN_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3dc7998a54..104e2eccc1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -17,6 +17,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -47,22 +48,6 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: - if isinstance(t, ts.ScalarType): - t = t.kind.name.lower() - try: - return { - "float32": "float", - "float64": "double", - "int32": "std::int32_t", - "int64": "std::int64_t", - "bool": "bool", - "axis_literal": None, # TODO: domain? - }[t] - except KeyError: - raise TypeError(f"Unsupported type '{t}'.") from None - - _vertical_dimension = "gtfn::unstructured::dim::vertical" _horizontal_dimension = "gtfn::unstructured::dim::horizontal" @@ -707,7 +692,7 @@ def dtype_to_cpp(x: ts.DataType) -> str: assert all(isinstance(i, ts.ScalarType) for i in x.types) return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) - res = pytype_to_cpptype(x) + res = cpp_utils.pytype_to_cpptype(x) assert isinstance(res, str) return res diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index a0f7711231..4a3e5d4e4c 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -29,17 +29,14 @@ def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: """Converts GT4Py scalar type to corresponding DaCe type.""" - if type_.kind == ts.ScalarKind.BOOL: - return dace.bool_ - elif type_.kind == ts.ScalarKind.INT32: - return dace.int32 - elif type_.kind == ts.ScalarKind.INT64: - return dace.int64 - elif type_.kind == ts.ScalarKind.FLOAT32: - return dace.float32 - elif type_.kind == ts.ScalarKind.FLOAT64: - return dace.float64 - raise ValueError(f"Scalar type '{type_}' not supported.") + + match type_.kind: + case ts.ScalarKind.BOOL: + return dace.bool_ + case ts.ScalarKind(): + return getattr(dace, type_.kind.name.lower()) + case _: + raise ValueError(f"Scalar type '{type_}' not supported.") def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 2c91e2d1b3..966ade5c03 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -29,7 +29,7 @@ from gt4py import eve from gt4py.next import common as gtx_common, utils as gtx_utils -from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator import builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -1507,7 +1507,7 @@ def _make_unstructured_shift( def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type IndexDType: Final = dace_utils.as_dace_type( - ts.ScalarType(kind=getattr(ts.ScalarKind, gtir.INTEGER_INDEX_BUILTIN.upper())) + ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) ) assert isinstance(node.fun, gtir.FunCall) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 2b3c5417cd..dfbba9c88b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -14,7 +14,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt -from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator import builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm @@ -75,7 +75,7 @@ def builtin_cast(*args: Any) -> str: val, target_type = args - assert target_type in gtir.TYPEBUILTINS + assert target_type in builtins.TYPE_BUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 983063a9cb..26373c647f 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -173,7 +173,7 @@ def apply_to_primitive_constituents( ... with_path_arg=True, ... tuple_constructor=lambda *elements: dict(elements), ... ) - {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} + {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} """ if isinstance(symbol_types[0], ts.TupleType): assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) @@ -254,7 +254,12 @@ def is_integer(symbol_type: ts.TypeSpec) -> bool: False """ return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in { + ts.ScalarKind.INT8, + ts.ScalarKind.UINT8, + ts.ScalarKind.INT16, + ts.ScalarKind.UINT16, ts.ScalarKind.INT32, + ts.ScalarKind.UINT32, ts.ScalarKind.INT64, } @@ -327,8 +332,14 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), + ts.ScalarKind.INT8: (np.iinfo(np.int8).min, np.iinfo(np.int8).max), + ts.ScalarKind.UINT8: (np.iinfo(np.uint8).min, np.iinfo(np.uint8).max), + ts.ScalarKind.INT16: (np.iinfo(np.int16).min, np.iinfo(np.int16).max), + ts.ScalarKind.UINT16: (np.iinfo(np.uint16).min, np.iinfo(np.uint16).max), ts.ScalarKind.INT32: (np.iinfo(np.int32).min, np.iinfo(np.int32).max), + ts.ScalarKind.UINT32: (np.iinfo(np.uint32).min, np.iinfo(np.uint32).max), ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), + ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), }[arithmetic_type.kind] diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index c1c0f0b5e1..2fbd039d16 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -61,11 +61,17 @@ def __str__(self) -> str: class ScalarKind(eve_types.IntEnum): BOOL = 1 - INT32 = 32 - INT64 = 64 - FLOAT32 = 1032 - FLOAT64 = 1064 - STRING = 3001 + INT8 = 2 + UINT8 = 3 + INT16 = 4 + UINT16 = 5 + INT32 = 6 + UINT32 = 7 + INT64 = 8 + UINT64 = 9 + FLOAT32 = 10 + FLOAT64 = 11 + STRING = 12 class ScalarType(DataType): diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index e601556e55..10b82f7861 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -41,16 +41,10 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: match dt: case np.bool_: return ts.ScalarKind.BOOL - case np.int32: - return ts.ScalarKind.INT32 - case np.int64: - return ts.ScalarKind.INT64 - case np.float32: - return ts.ScalarKind.FLOAT32 - case np.float64: - return ts.ScalarKind.FLOAT64 case np.str_: return ts.ScalarKind.STRING + case np.dtype(): + return getattr(ts.ScalarKind, dt.name.upper()) case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index c2b98ee8d9..89ad556476 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -192,13 +192,13 @@ class UniqueInitializer(DataInitializer): data containers. """ - start: int = 0 + start: int = 1 @property def scalar_value(self) -> ScalarValue: start = self.start self.start += 1 - return np.int64(start) + return start def field( self, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index e301dbe11b..95bde32107 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -318,7 +318,9 @@ def testee(a: int32, b: int32, c: cases.IField) -> cases.IField: # not inlined return tmp2 * tmp2 * c - cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * c) + cases.verify_with_default_data( + cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * a * b * a * b * c + ) @pytest.mark.uses_scalar_in_domain_and_fo @@ -1126,7 +1128,7 @@ def implicit_broadcast_scalar(inp: cases.EmptyField): inp = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() out = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() - cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(0)) + cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(1)) def test_implicit_broadcast_mixed_dim(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index c79f8dbb6b..09dc04acb1 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -18,6 +18,7 @@ index, named_range, shift, + INTEGER_INDEX_BUILTIN, ) from gt4py.next.iterator.runtime import fendef, fundef, set_at @@ -68,7 +69,7 @@ def test_index_builtin(program_processor): program_processor, validate = program_processor isize = 10 - out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) run_processor(index_program_simple, program_processor, out, isize, offset_provider={}) if validate: @@ -91,7 +92,7 @@ def test_index_builtin_shift(program_processor): program_processor, validate = program_processor isize = 10 - out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I}) if validate: diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index eaeb76b404..3e3df069bf 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -35,8 +35,8 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a.asnumpy()) == 0 - assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a.asnumpy()) == 1 + assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) b = cases.allocate(cartesian_case, mixed_args, "b")() @@ -45,7 +45,7 @@ def test_allocate_default_unique(cartesian_case): c = cases.allocate(cartesian_case, mixed_args, "c")() assert np.min(c.asnumpy()) == b + 1 - assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + 1 def test_allocate_return_default_zeros(cartesian_case): diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 50db24b880..154b666c5d 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -9,7 +9,7 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir Vertex = gtx.Dimension("Vertex") @@ -46,7 +46,7 @@ [7, 17, 1, 16], [8, 15, 2, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) @@ -63,7 +63,7 @@ [8, 1, 6, 4], [6, 2, 7, 5], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) @@ -89,7 +89,7 @@ [7, 1], [8, 2], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) @@ -107,7 +107,7 @@ [7, 13, 6, 16], [8, 14, 7, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index fa9a0220ef..cbaa84454d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -17,7 +17,7 @@ from gt4py.next import errors from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -147,7 +147,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) value="1", type=ts.ScalarType( kind=getattr( - ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper() + ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper() ) ), ), @@ -174,7 +174,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) value="2", type=ts.ScalarType( kind=getattr( - ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper() + ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper() ) ), ), diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index af9084f407..f825c3823b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -111,7 +111,7 @@ def test_tuple_get(): testee = "x[42]" expected = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 6b45f470b7..b0f7021bc0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -200,7 +200,7 @@ def test_shift(): def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) expected = "x[42]" actual = pformat(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 9d51dc4f33..52d77e5fda 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -9,7 +9,7 @@ from typing import Optional from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import global_tmps, infer_domain from gt4py.next.iterator.type_system import inference as type_inference @@ -19,7 +19,7 @@ IDim = common.Dimension(value="IDim") JDim = common.Dimension(value="JDim") KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index b1e051c82b..51b6bf512b 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -33,10 +33,10 @@ def test_render_function_declaration_scalar(function_scalar_example): expected = format_source( "cpp", """\ - decltype(auto) example(double a, std::int64_t b) { +decltype(auto) example(double a, std::int64_t b) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -81,11 +81,11 @@ def test_render_function_declaration_buffer(function_buffer_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf, ArgT1 &&b_buf) { +template + decltype(auto) example(ArgT0&& a_buf, ArgT1&& b_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -132,11 +132,11 @@ def test_render_function_declaration_tuple(function_tuple_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf) { +template + decltype(auto) example(ArgT0&& a_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 0586d48703..1afd6e8113 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -13,7 +13,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -41,8 +41,8 @@ def program_example(): fun=itir.SymRef(id="named_range"), args=[ itir.AxisLiteral(value="I"), - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal("10", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), + im.literal("10", builtins.INTEGER_INDEX_BUILTIN), ], ) ], From 6e486d808206dfbee2865415deb1f7c4d35209bb Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 21 Jan 2025 14:05:27 +0100 Subject: [PATCH 105/178] refactor[next]: remove DaCe to GTFN dependency for IR hash functions (#1809) This PR moves some utility functions for caching of GTIR program, used by both DaCe and GTFN backends, to `otf.stages` module close to `otf.stages.CompilableProgram`. --- src/gt4py/next/otf/stages.py | 40 +++++++++++++++++ .../next/program_processors/runners/dace.py | 6 +-- .../next/program_processors/runners/gtfn.py | 45 +------------------ .../gtfn_tests/test_gtfn_module.py | 20 ++++----- 4 files changed, 55 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 22326c7e87..ff4285d72d 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -11,6 +11,8 @@ import dataclasses from typing import Any, Generic, Optional, Protocol, TypeAlias, TypeVar +from gt4py.eve import utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, languages, toolchain from gt4py.next.otf.binding import interface @@ -29,6 +31,44 @@ CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] +def compilation_hash(otf_closure: CompilableProgram) -> int: + """Given closure compute a hash uniquely determining if we need to recompile.""" + offset_provider = otf_closure.args.offset_provider + return hash( + ( + otf_closure.data, + # As the frontend types contain lists they are not hashable. As a workaround we just + # use content_hash here. + utils.content_hash(tuple(arg for arg in otf_closure.args.args)), + # Directly using the `id` of the offset provider is not possible as the decorator adds + # the implicitly defined ones (i.e. to allow the `TDim + 1` syntax) resulting in a + # different `id` every time. Instead use the `id` of each individual offset provider. + tuple((k, id(v)) for (k, v) in offset_provider.items()) if offset_provider else None, + otf_closure.args.column_axis, + ) + ) + + +def fingerprint_compilable_program(inp: CompilableProgram) -> str: + """ + Generates a unique hash string for a stencil source program representing + the program, sorted offset_provider, and column_axis. + """ + program: itir.Program = inp.data + offset_provider: common.OffsetProvider = inp.args.offset_provider + column_axis: Optional[common.Dimension] = inp.args.column_axis + + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + return program_hash + + @dataclasses.dataclass(frozen=True) class ProgramSource(Generic[SrcL, SettingT]): """ diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 1b3b930818..b7f419a749 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -11,12 +11,11 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.next import backend -from gt4py.next.otf import workflow +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory -class DaCeFieldviewBackendFactory(GTFNBackendFactory): +class DaCeFieldviewBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -36,6 +35,7 @@ class Params: name_cached="_cached", ) device_type = core_defs.DeviceType.CPU + hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( dace_fieldview_workflow.DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c0a9be9168..a8961fd9bc 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -18,10 +18,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators -from gt4py.eve import utils -from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -102,44 +99,6 @@ def extract_connectivity_args( return args -def compilation_hash(otf_closure: stages.CompilableProgram) -> int: - """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = otf_closure.args.offset_provider - return hash( - ( - otf_closure.data, - # As the frontend types contain lists they are not hashable. As a workaround we just - # use content_hash here. - content_hash(tuple(arg for arg in otf_closure.args.args)), - # Directly using the `id` of the offset provider is not possible as the decorator adds - # the implicitly defined ones (i.e. to allow the `TDim + 1` syntax) resulting in a - # different `id` every time. Instead use the `id` of each individual offset provider. - tuple((k, id(v)) for (k, v) in offset_provider.items()) if offset_provider else None, - otf_closure.args.column_axis, - ) - ) - - -def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: - """ - Generates a unique hash string for a stencil source program representing - the program, sorted offset_provider, and column_axis. - """ - program: itir.Program = inp.data - offset_provider: common.OffsetProvider = inp.args.offset_provider - column_axis: Optional[common.Dimension] = inp.args.column_axis - - program_hash = utils.content_hash( - ( - program, - sorted(offset_provider.items(), key=lambda el: el[0]), - column_axis, - ) - ) - - return program_hash - - class FileCache(diskcache.Cache): """ This class extends `diskcache.Cache` to ensure the cache is properly @@ -189,7 +148,7 @@ class Params: translation=factory.LazyAttribute( lambda o: workflow.CachedStep( o.bare_translation, - hash_function=fingerprint_compilable_program, + hash_function=stages.fingerprint_compilable_program, cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), ) ), @@ -236,7 +195,7 @@ class Params: name_cached="_cached", ) device_type = core_defs.DeviceType.CPU - hash_function = compilation_hash + hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 1afd6e8113..53e463c6c7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -94,7 +94,7 @@ def test_hash_and_diskcache(program_example, tmp_path): *parameters, **{"offset_provider": {}} ), ) - hash = gtfn.fingerprint_compilable_program(compilable_program) + hash = stages.fingerprint_compilable_program(compilable_program) with diskcache.Cache(tmp_path) as cache: cache[hash] = compilable_program @@ -107,27 +107,27 @@ def test_hash_and_diskcache(program_example, tmp_path): del reopened_cache[hash] # delete data # hash creation is deterministic - assert hash == gtfn.fingerprint_compilable_program(compilable_program) - assert hash == gtfn.fingerprint_compilable_program(compilable_program_from_cache) + assert hash == stages.fingerprint_compilable_program(compilable_program) + assert hash == stages.fingerprint_compilable_program(compilable_program_from_cache) # hash is different if program changes altered_program_id = copy.deepcopy(compilable_program) altered_program_id.data.id = "example2" - assert gtfn.fingerprint_compilable_program( + assert stages.fingerprint_compilable_program( compilable_program - ) != gtfn.fingerprint_compilable_program(altered_program_id) + ) != stages.fingerprint_compilable_program(altered_program_id) altered_program_offset_provider = copy.deepcopy(compilable_program) object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim}) - assert gtfn.fingerprint_compilable_program( + assert stages.fingerprint_compilable_program( compilable_program - ) != gtfn.fingerprint_compilable_program(altered_program_offset_provider) + ) != stages.fingerprint_compilable_program(altered_program_offset_provider) altered_program_column_axis = copy.deepcopy(compilable_program) object.__setattr__(altered_program_column_axis.args, "column_axis", KDim) - assert gtfn.fingerprint_compilable_program( + assert stages.fingerprint_compilable_program( compilable_program - ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) + ) != stages.fingerprint_compilable_program(altered_program_column_axis) def test_gtfn_file_cache(program_example): @@ -146,7 +146,7 @@ def test_gtfn_file_cache(program_example): gpu=False, cached=True, otf_workflow__cached_translation=False ).executor.step.translation - cache_key = gtfn.fingerprint_compilable_program(compilable_program) + cache_key = stages.fingerprint_compilable_program(compilable_program) # ensure the actual cached step in the backend generates the cache item for the test if cache_key in (translation_cache := cached_gtfn_translation_step.cache): From 44578ecea1132554c9a54c10205d7fb9e99ad81b Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 21 Jan 2025 16:02:45 +0100 Subject: [PATCH 106/178] ci: disable test config for DaintXC (#1812) The GitLab configuration `.container-runner-daint-gpu` was removed upstream. The GT4Py CI was still referring to it, and therefore it failed. --- ci/cscs-ci.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index ad919d6bc0..b5ea07b787 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -119,17 +119,17 @@ build_py310_image_aarch64: SLURM_TIMELIMIT: 15 NUM_PROCESSES: auto VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 -.test_helper_x86_64: - extends: [.container-runner-daint-gpu, .test_helper] - parallel: - matrix: - - SUBPACKAGE: [cartesian, storage] - VARIANT: [-internal, -dace] - SUBVARIANT: [-cuda11x, -cpu] - - SUBPACKAGE: eve - - SUBPACKAGE: next - VARIANT: [-nomesh, -atlas] - SUBVARIANT: [-cuda11x, -cpu] +# .test_helper_x86_64: +# extends: [.container-runner-daint-gpu, .test_helper] +# parallel: +# matrix: +# - SUBPACKAGE: [cartesian, storage] +# VARIANT: [-internal, -dace] +# SUBVARIANT: [-cuda11x, -cpu] +# - SUBPACKAGE: eve +# - SUBPACKAGE: next +# VARIANT: [-nomesh, -atlas] +# SUBVARIANT: [-cuda11x, -cpu] .test_helper_aarch64: extends: [.container-runner-daint-gh200, .test_helper] parallel: From 9bbb95276e31c19f5c67e1de32e279265ed6cf3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 22 Jan 2025 09:02:38 +0100 Subject: [PATCH 107/178] fix[next][dace]: Fix for `DistributedBufferRelocator` (#1814) This PR fixes a bug in `DistributedBufferRelocator` that was observed in ICON4Py's `TestUpdateThetaAndExner` test. In essence there was an `assert` that assumed that checked if this temporary was a sink node, but, the code that finds all write backs was never excluding such cases, i.e. the temporaries that were selected might not be sink nodes in the state where they are defined. The `assert` was not part of the original implementation and is not a requirement of the transformation, instead it was introduced by [PR#1799](https://github.com/GridTools/gt4py/pull/1799), that fixed some issues in the analysis of read write dependencies. There are two solutions for this, either removing the `assert` or prune these kinds of temporaries. After some consideration, it was realized that handling such cases will not lead to invalid SDFG, as long as the other restrictions on the global data are respected. For that reason the `assert` was removed. However, we should thinking of doing something more intelligent in that case. --- .../transformations/simplify.py | 13 +++- .../test_distributed_buffer_relocator.py | 60 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index bb95244aef..cc845505c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -551,11 +551,17 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: def_locations: list[AccessLocation] = [] for upstream_state in find_upstream_states(temp_storage_state): if temp_storage_node.data in access_sets[upstream_state][1]: - def_locations.extend( + # NOTE: We do not impose any restriction on `temp_storage`. Thus + # It could be that we do read from it (we can never write to it) + # in this state or any other state later. + # TODO(phimuell): Should we require that `temp_storage` is a sink + # node? It might prevent or allow other optimizations. + new_locations = [ (data_node, upstream_state) for data_node in upstream_state.data_nodes() if data_node.data == temp_storage_node.data - ) + ] + def_locations.extend(new_locations) if len(def_locations) != 0: result_candidates.append((temp_storage, def_locations)) @@ -677,7 +683,6 @@ def _check_read_write_dependency_impl( # Get the location and the state where the temporary is originally defined. def_location_of_intermediate, state_to_inspect = target_location - assert state_to_inspect.out_degree(def_location_of_intermediate) == 0 # These are all access nodes that refers to the global data, that we want # to move into the state `state_to_inspect`. We need them to do the @@ -689,6 +694,8 @@ def _check_read_write_dependency_impl( # empty Memlets. This is done because such Memlets are used to induce a # schedule or order in the dataflow graph. # As a byproduct, for the second test, we also collect all of these nodes. + # TODO(phimuell): Refine this such that it takes the location of the data + # into account. for dnode in state_to_inspect.data_nodes(): if dnode.data != global_data_name: continue diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index d61b8a2d42..709079dd0d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -287,3 +287,63 @@ def test_distributed_buffer_global_memory_data_no_rance2(): res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_non_sink_temporary_sdfg() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg")) + state = sdfg.add_state(is_start_block=True) + wb_state = sdfg.add_state_after(state) + + names = ["a", "b", "c", "t1", "t2"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t1"].transient = True + sdfg.arrays["t2"].transient = True + t1 = state.add_access("t1") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("a[__i]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("t1[__i]")}, + output_nodes={t1}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("t1[__i]")}, + code="__out = __in1 / 2.0", + outputs={"__out": dace.Memlet("t2[__i]")}, + input_nodes={t1}, + external_edges=True, + ) + + wb_state.add_nedge(wb_state.add_access("t1"), wb_state.add_access("b"), dace.Memlet("t1[0:10]")) + wb_state.add_nedge(wb_state.add_access("t2"), wb_state.add_access("b"), dace.Memlet("t2[0:10]")) + + sdfg.validate() + return sdfg, state, wb_state + + +def test_distributed_buffer_non_sink_temporary(): + """Tests the transformation if one of the temporaries is not a sink node. + + Note that the SDFG has two temporaries, `t1` is not a sink node and `t2` is + a sink node. + """ + sdfg, state, wb_state = _make_distributed_buffer_non_sink_temporary_sdfg() + assert wb_state.number_of_nodes() == 4 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + sdfg.view() + assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t1", "t2"} + assert wb_state.number_of_nodes() == 0 From 89d3b0d1c2eb4246096688e64cd0e76b89d6664e Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 22 Jan 2025 10:00:08 +0100 Subject: [PATCH 108/178] refactor[next][dace]: cleanup dace backend module (#1811) PR #1753 removed the `dace_iterator` backend, that before existed by side of `dace_fieldview` backend. Code common to these two backends was placed in `dace_common` module. The current GTIR-DaCe backend supports both iterator view and field view flavors of the GTIR, so there is no need for a distinction at the top level of the backend module. This PR applies the following cleanup tasks: ``` runners/dace_common/dace_backend -> runners/dace_fieldview/sdfg_callable runners/dace_common/utility -> runners/dace_fieldview/utils runners/dace_common/workflow, runners/dace_fieldview/workflow -> runners/dace_fieldview/workflow runners/dace_fiedview/utility -> runners/dace_fieldview/gtir_sdfg_utils runners/dace_fieldview -> runners/dace ``` The module `runners/dace/workflow` was also split into sub-modules. A doc string was added to `runners/dace/workflow/__init__.py` --- src/gt4py/next/ffront/decorator.py | 2 +- .../runners/dace/__init__.py | 27 ++++++ .../gtir_builtin_translators.py | 28 +++--- .../{dace_fieldview => dace}/gtir_dataflow.py | 44 ++++----- .../gtir_python_codegen.py | 0 .../gtir_scan_translator.py | 19 ++-- .../{dace_fieldview => dace}/gtir_sdfg.py | 47 ++++----- .../utility.py => dace/gtir_sdfg_utils.py} | 18 +++- .../{dace_fieldview => dace}/program.py | 24 ++--- .../dace_backend.py => dace/sdfg_callable.py} | 4 +- .../transformations/__init__.py | 2 +- .../transformations/auto_optimize.py | 4 +- .../transformations/gpu_utils.py | 4 +- .../transformations/local_double_buffering.py | 8 +- .../transformations/loop_blocking.py | 4 +- .../transformations/map_fusion_helper.py | 0 .../transformations/map_fusion_parallel.py | 0 .../transformations/map_fusion_serial.py | 0 .../transformations/map_orderer.py | 4 +- .../transformations/map_promoter.py | 4 +- .../transformations/simplify.py | 20 ++-- .../transformations/strides.py | 6 +- .../util.py => dace/transformations/utils.py} | 4 +- .../{dace_common/utility.py => dace/utils.py} | 19 +--- .../runners/dace/workflow/__init__.py | 20 ++++ .../{dace.py => dace/workflow/backend.py} | 18 ++-- .../workflow/compilation.py} | 81 +--------------- .../runners/dace/workflow/decoration.py | 96 +++++++++++++++++++ .../runners/dace/workflow/factory.py | 58 +++++++++++ .../workflow/translation.py} | 42 +------- .../runners/dace_common/__init__.py | 8 -- .../runners/dace_fieldview/__init__.py | 17 ---- .../dace_tests/test_gtir_to_sdfg.py | 2 +- .../test_constant_substitution.py | 2 +- .../test_create_local_double_buffering.py | 2 +- .../test_distributed_buffer_relocator.py | 2 +- .../test_global_self_copy_elimination.py | 2 +- .../transformation_tests/test_gpu_utils.py | 2 +- .../test_loop_blocking.py | 2 +- .../test_map_buffer_elimination.py | 2 +- .../transformation_tests/test_map_fusion.py | 2 +- .../transformation_tests/test_map_order.py | 2 +- .../test_move_tasklet_into_map.py | 2 +- .../test_serial_map_promoter.py | 2 +- .../transformation_tests/test_strides.py | 2 +- 45 files changed, 357 insertions(+), 301 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/__init__.py rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/gtir_builtin_translators.py (97%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/gtir_dataflow.py (98%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/gtir_python_codegen.py (100%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/gtir_scan_translator.py (97%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/gtir_sdfg.py (96%) rename src/gt4py/next/program_processors/runners/{dace_fieldview/utility.py => dace/gtir_sdfg_utils.py} (86%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/program.py (91%) rename src/gt4py/next/program_processors/runners/{dace_common/dace_backend.py => dace/sdfg_callable.py} (98%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/__init__.py (96%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/auto_optimize.py (99%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/gpu_utils.py (99%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/local_double_buffering.py (98%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/loop_blocking.py (99%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/map_fusion_helper.py (100%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/map_fusion_parallel.py (100%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/map_fusion_serial.py (100%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/map_orderer.py (97%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/map_promoter.py (99%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/simplify.py (98%) rename src/gt4py/next/program_processors/runners/{dace_fieldview => dace}/transformations/strides.py (99%) rename src/gt4py/next/program_processors/runners/{dace_fieldview/transformations/util.py => dace/transformations/utils.py} (98%) rename src/gt4py/next/program_processors/runners/{dace_common/utility.py => dace/utils.py} (79%) create mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/__init__.py rename src/gt4py/next/program_processors/runners/{dace.py => dace/workflow/backend.py} (72%) rename src/gt4py/next/program_processors/runners/{dace_common/workflow.py => dace/workflow/compilation.py} (50%) create mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/decoration.py create mode 100644 src/gt4py/next/program_processors/runners/dace/workflow/factory.py rename src/gt4py/next/program_processors/runners/{dace_fieldview/workflow.py => dace/workflow/translation.py} (71%) delete mode 100644 src/gt4py/next/program_processors/runners/dace_common/__init__.py delete mode 100644 src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 7e2abc44fb..ecaf1a76b4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -306,7 +306,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: try: - from gt4py.next.program_processors.runners.dace_fieldview.program import Program + from gt4py.next.program_processors.runners.dace.program import Program except ImportError: pass diff --git a/src/gt4py/next/program_processors/runners/dace/__init__.py b/src/gt4py/next/program_processors/runners/dace/__init__.py new file mode 100644 index 0000000000..8540585494 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/__init__.py @@ -0,0 +1,27 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.program_processors.runners.dace.gtir_sdfg import build_sdfg_from_gtir +from gt4py.next.program_processors.runners.dace.sdfg_callable import get_sdfg_args +from gt4py.next.program_processors.runners.dace.workflow.backend import ( + run_dace_cpu, + run_dace_cpu_noopt, + run_dace_gpu, + run_dace_gpu_noopt, +) + + +__all__ = [ + "build_sdfg_from_gtir", + "get_sdfg_args", + "run_dace_cpu", + "run_dace_cpu_noopt", + "run_dace_gpu", + "run_dace_gpu_noopt", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py rename to src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index b0d09e0a15..0fe776c3ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -23,19 +23,19 @@ domain_utils, ir_makers as im, ) -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( gtir_dataflow, gtir_python_codegen, gtir_sdfg, - utility as dace_gtir_utils, + gtir_sdfg_utils, + utils as gtx_dace_utils, ) -from gt4py.next.program_processors.runners.dace_fieldview.gtir_scan_translator import translate_scan +from gt4py.next.program_processors.runners.dace.gtir_scan_translator import translate_scan from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg + from gt4py.next.program_processors.runners.dace import gtir_sdfg def get_domain_indices( @@ -54,7 +54,7 @@ def get_domain_indices( as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before being used in memlet subset because ranges are better supported throughout DaCe. """ - index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + index_variables = [dace.symbolic.SymExpr(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims] if offsets is None: return dace_subsets.Indices(index_variables) else: @@ -158,7 +158,7 @@ def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData """ if isinstance(arg, tuple): tuple_type = get_tuple_type(arg) - tuple_symbols = dace_gtir_utils.flatten_tuple_fields(name, tuple_type) + tuple_symbols = gtir_sdfg_utils.flatten_tuple_fields(name, tuple_type) tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) return [ (str(tsym.id), tfield) @@ -353,7 +353,7 @@ def _create_field_operator( "fieldop", state, ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + gtir_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" for dim, lower_bound, upper_bound in domain }, ) @@ -372,7 +372,7 @@ def _create_field_operator( ) else: # handle tuples of fields - output_symbol_tree = dace_gtir_utils.make_symbol_tree("x", node_type) + output_symbol_tree = gtir_sdfg_utils.make_symbol_tree("x", node_type) return gtx_utils.tree_map( lambda output_edge, output_sym: _create_field_operator_impl( sdfg_builder, sdfg, state, domain, output_edge, output_sym.type, map_exit @@ -590,13 +590,13 @@ def translate_index( domain = extract_domain(node.annex.domain) assert len(domain) == 1 dim, _, _ = domain[0] - dim_index = dace_gtir_utils.get_map_variable(dim) + dim_index = gtir_sdfg_utils.get_map_variable(dim) index_data, _ = sdfg_builder.add_temp_scalar(sdfg, INDEX_DTYPE) index_node = state.add_access(index_data) index_value = gtir_dataflow.ValueExpr( dc_node=index_node, - gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + gt_dtype=gtx_dace_utils.as_itir_type(INDEX_DTYPE), ) index_write_tasklet = sdfg_builder.add_tasklet( "index", @@ -643,7 +643,7 @@ def _get_data_nodes( return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): - symbol_tree = dace_gtir_utils.make_symbol_tree(data_name, data_type) + symbol_tree = gtir_sdfg_utils.make_symbol_tree(data_name, data_type) return gtx_utils.tree_map( lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) )(symbol_tree) @@ -669,7 +669,7 @@ def _get_symbolic_value( ) temp_name, _ = sdfg.add_scalar( temp_name or sdfg.temp_data_name(), - dace_utils.as_dace_type(scalar_type), + gtx_dace_utils.as_dace_type(scalar_type), find_new_name=True, transient=True, ) @@ -808,7 +808,7 @@ def translate_scalar_expr( dace.Memlet(data=arg_node.data, subset="0"), ) # finally, create temporary for the result value - temp_name, _ = sdfg_builder.add_temp_scalar(sdfg, dace_utils.as_dace_type(node.type)) + temp_name, _ = sdfg_builder.add_temp_scalar(sdfg, gtx_dace_utils.as_dace_type(node.type)) temp_node = state.add_access(temp_name) state.add_edge( tasklet_node, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py rename to src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 966ade5c03..e00e363ac4 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -32,11 +32,11 @@ from gt4py.next.iterator import builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( gtir_python_codegen, gtir_sdfg, - utility as dace_gtir_utils, + gtir_sdfg_utils, + utils as gtx_dace_utils, ) from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -274,7 +274,7 @@ def connect( def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: assert isinstance(node.type, ts.ScalarType) - dc_dtype = dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_utils.as_dace_type(node.type) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 2 @@ -467,7 +467,7 @@ def _construct_tasklet_result( src_connector: str, use_array: bool = False, ) -> ValueExpr: - data_type = dace_utils.as_itir_type(dc_dtype) + data_type = gtx_dace_utils.as_itir_type(dc_dtype) if use_array: # In some cases, such as result data with list-type annotation, we want # that output data is represented as an array (single-element 1D array) @@ -676,7 +676,7 @@ def _visit_if_branch( if isinstance(arg, tuple): ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) - psymbol_tree = dace_gtir_utils.make_symbol_tree(pname, ptype) + psymbol_tree = gtir_sdfg_utils.make_symbol_tree(pname, ptype) inner_arg = gtx_utils.tree_map( lambda tsym, targ: self._visit_if_branch_arg( if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets @@ -772,7 +772,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp ) nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) - nsdfg.debuginfo = dace_utils.debug_info(node, default=self.sdfg.debuginfo) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) # create states inside the nested SDFG for the if-branches if use_conditional_block: @@ -812,7 +812,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp edge.connect(map_entry=None) if isinstance(node.type, ts.TupleType): - out_symbol_tree = dace_gtir_utils.make_symbol_tree("__output", node.type) + out_symbol_tree = gtir_sdfg_utils.make_symbol_tree("__output", node.type) outer_value = gtx_utils.tree_map( lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) )(output_tree, out_symbol_tree) @@ -880,7 +880,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) field_desc = it.field.desc(self.sdfg) - connectivity = dace_utils.connectivity_identifier(offset) + connectivity = gtx_dace_utils.connectivity_identifier(offset) # initially, the storage for the connectivty tables is created as transient; # when the tables are used, the storage is changed to non-transient, # as the corresponding arrays are supposed to be allocated by the SDFG caller @@ -926,7 +926,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) neighbors_node = self.state.add_access(neighbors_temp) offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL) - neighbor_idx = dace_gtir_utils.get_map_variable(offset_type) + neighbor_idx = gtir_sdfg_utils.get_map_variable(offset_type) index_connector = "__index" output_connector = "__val" @@ -979,7 +979,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: list_desc = list_arg.dc_node.desc(self.sdfg) assert len(list_desc.shape) == 1 - result_dtype = dace_utils.as_dace_type(list_arg.gt_dtype.element_type) + result_dtype = gtx_dace_utils.as_dace_type(list_arg.gt_dtype.element_type) result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, result_dtype) result_node = self.state.add_access(result) @@ -1039,7 +1039,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: assert len(node.fun.args) == 1 # the operation to be mapped on the arguments assert isinstance(node.type.element_type, ts.ScalarType) - dc_dtype = dace_utils.as_dace_type(node.type.element_type) + dc_dtype = gtx_dace_utils.as_dace_type(node.type.element_type) input_connectors = [f"__arg{i}" for i in range(len(node.args))] output_connector = "__out" @@ -1081,7 +1081,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: raise ValueError("Unexpected arguments to map expression with different neighborhood.") offset_type, offset_provider_type = next(iter(input_connectivity_types.items())) local_size = offset_provider_type.max_neighbors - map_index = dace_gtir_utils.get_map_variable(offset_type) + map_index = gtir_sdfg_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to @@ -1115,11 +1115,11 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if offset_provider_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. - connectivity = dace_utils.connectivity_identifier(offset_type.value) + connectivity = gtx_dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) + origin_map_index = gtir_sdfg_utils.get_map_variable(offset_provider_type.source_dim) connectivity_slice = self._construct_local_view( MemletExpr( @@ -1187,14 +1187,14 @@ def _make_reduce_with_skip_values( corresponding neighbor index in the connectivity table is valid, or the identity value if the neighbor index is missing. """ - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) + origin_map_index = gtir_sdfg_utils.get_map_variable(offset_provider_type.source_dim) assert ( isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - connectivity = dace_utils.connectivity_identifier(offset_type.value) + connectivity = gtx_dace_utils.connectivity_identifier(offset_type.value) connectivity_node = self.state.add_access(connectivity) connectivity_desc = connectivity_node.desc(self.sdfg) connectivity_desc.transient = False @@ -1506,7 +1506,7 @@ def _make_unstructured_shift( def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type - IndexDType: Final = dace_utils.as_dace_type( + IndexDType: Final = gtx_dace_utils.as_dace_type( ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) ) @@ -1536,7 +1536,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = dace_utils.connectivity_identifier(offset) + offset_table = gtx_dace_utils.connectivity_identifier(offset) self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) @@ -1604,13 +1604,13 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: # Therefore we handle `ListType` as a single-element array with shape (1,) # that will be accessed in a map expression on a local domain. assert isinstance(node.type.element_type, ts.ScalarType) - dc_dtype = dace_utils.as_dace_type(node.type.element_type) + dc_dtype = gtx_dace_utils.as_dace_type(node.type.element_type) # In order to ease the lowring of the parent expression on local dimension, # we represent the scalar value as a single-element 1D array. use_array = True else: assert isinstance(node.type, ts.ScalarType) - dc_dtype = dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_utils.as_dace_type(node.type) use_array = False return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) @@ -1713,7 +1713,7 @@ def _visit_Lambda_impl( ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: - dc_dtype = dace_utils.as_dace_type(node.type) + dc_dtype = gtx_dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) def visit_SymRef( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py rename to src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py rename to src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index 27551a68bf..d3d8e101a7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -30,18 +30,17 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( gtir_builtin_translators as gtir_translators, gtir_dataflow, gtir_sdfg, - utility as dace_gtir_utils, + gtir_sdfg_utils, ) from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg + from gt4py.next.program_processors.runners.dace import gtir_sdfg def _parse_scan_fieldop_arg( @@ -215,7 +214,7 @@ def _create_scan_field_operator( "fieldop", state, ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + gtir_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" for dim, lower_bound, upper_bound in domain if not sdfg_builder.is_column_axis(dim) }, @@ -234,7 +233,7 @@ def _create_scan_field_operator( # handle tuples of fields # the symbol name 'x' in the call below is not used, we only need # the tree structure of the `TupleType` definition to pass to `tree_map()` - output_symbol_tree = dace_gtir_utils.make_symbol_tree("x", node_type) + output_symbol_tree = gtir_sdfg_utils.make_symbol_tree("x", node_type) return gtx_utils.tree_map( lambda output_edge, output_sym: ( _create_scan_field_operator_impl( @@ -317,7 +316,7 @@ def _lower_lambda_to_nested_sdfg( # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) - nsdfg.debuginfo = dace_utils.debug_info(lambda_node, default=sdfg.debuginfo) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo) lambda_translator = sdfg_builder.setup_nested_context( lambda_node, nsdfg, lambda_symbols, lambda_field_offsets ) @@ -332,7 +331,7 @@ def _lower_lambda_to_nested_sdfg( scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] # extract the scan loop range - scan_loop_var = dace_gtir_utils.get_map_variable(scan_dim) + scan_loop_var = gtir_sdfg_utils.get_map_variable(scan_dim) # in case the scan operator computes a list (not a scalar), we need to add an extra dimension def get_scan_output_shape( @@ -358,7 +357,7 @@ def get_scan_output_shape( # This dataflow will write the initial value of the scan carry variable. init_state = nsdfg.add_state("scan_init", is_start_block=True) scan_carry_input = ( - dace_gtir_utils.make_symbol_tree(scan_carry_symbol.id, scan_carry_symbol.type) + gtir_sdfg_utils.make_symbol_tree(scan_carry_symbol.id, scan_carry_symbol.type) if isinstance(scan_carry_symbol.type, ts.TupleType) else scan_carry_symbol ) @@ -620,7 +619,7 @@ def translate_scan( if isinstance(scan_carry_type, ts.TupleType): lambda_flat_outs = { str(sym.id): sym.type - for sym in dace_gtir_utils.flatten_tuple_fields( + for sym in gtir_sdfg_utils.flatten_tuple_fields( _scan_output_name(scan_carry), scan_carry_type ) } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py similarity index 96% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py rename to src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py index 2139ffe578..b306a59305 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -31,10 +31,10 @@ from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils from gt4py.next.iterator.type_system import inference as gtir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( gtir_builtin_translators, - utility as dace_gtir_utils, + gtir_sdfg_utils, + utils as gtx_dace_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -332,17 +332,17 @@ def _make_array_shape_and_strides( Two lists of symbols, one for the shape and the other for the strides of the array. """ dc_dtype = gtir_builtin_translators.INDEX_DTYPE - neighbor_table_types = dace_utils.filter_connectivity_types(self.offset_provider_type) + neighbor_table_types = gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) shape = [ ( neighbor_table_types[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) + else dace.symbol(gtx_dace_utils.field_size_symbol_name(name, i), dc_dtype) ) for i, dim in enumerate(dims) ] strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dc_dtype) + dace.symbol(gtx_dace_utils.field_stride_symbol_name(name, i), dc_dtype) for i in range(len(dims)) ] return shape, strides @@ -383,7 +383,7 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): + for sym in gtir_sdfg_utils.flatten_tuple_fields(name, gt_type): assert isinstance(sym.type, ts.DataType) tuple_fields.extend( self._add_storage(sdfg, symbolic_arguments, sym.id, sym.type, transient) @@ -397,7 +397,7 @@ def _add_storage( if not isinstance(gt_type.dtype, ts.ScalarType): raise ValueError(f"Field type '{gt_type.dtype}' not supported.") # handle default case: field with one or more dimensions - dc_dtype = dace_utils.as_dace_type(gt_type.dtype) + dc_dtype = gtx_dace_utils.as_dace_type(gt_type.dtype) # Use symbolic shape, which allows to invoke the program with fields of different size; # and symbolic strides, which enables decoupling the memory layout from generated code. sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) @@ -405,8 +405,8 @@ def _add_storage( return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): - dc_dtype = dace_utils.as_dace_type(gt_type) - if dace_utils.is_field_symbol(name) or name in symbolic_arguments: + dc_dtype = gtx_dace_utils.as_dace_type(gt_type) + if gtx_dace_utils.is_field_symbol(name) or name in symbolic_arguments: if name in sdfg.symbols: # Sometimes, when the field domain is implicitly derived from the # field domain, the gt4py lowering adds the field size as a scalar @@ -416,7 +416,7 @@ def _add_storage( # created by `_make_array_shape_and_strides()`, when allocating # storage for field arguments. We assume that the scalar argument # for field size, if present, always follows the field argument. - assert dace_utils.is_field_symbol(name) + assert gtx_dace_utils.is_field_symbol(name) if sdfg.symbols[name].dtype != dc_dtype: raise ValueError( f"Type mismatch on argument {name}: got {dc_dtype}, expected {sdfg.symbols[name].dtype}." @@ -501,7 +501,7 @@ def _add_sdfg_params( self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables - for offset, connectivity_type in dace_utils.filter_connectivity_types( + for offset, connectivity_type in gtx_dace_utils.filter_connectivity_types( self.offset_provider_type ).items(): scalar_type = tt.from_dtype(connectivity_type.dtype) @@ -514,7 +514,10 @@ def _add_sdfg_params( # the connectivity tables that are not used. The remaining unused transient arrays # are removed by the dace simplify pass. self._add_storage( - sdfg, symbolic_arguments, dace_utils.connectivity_identifier(offset), gt_type + sdfg, + symbolic_arguments, + gtx_dace_utils.connectivity_identifier(offset), + gt_type, ) # the list of all sdfg arguments (aka non-transient arrays) which include tuple-element fields @@ -537,7 +540,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: assert len(self.field_offsets) == 0 sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node) + sdfg.debuginfo = gtir_sdfg_utils.debug_info(node) # start block of the stateful graph entry_state = sdfg.add_state("program_entry", is_start_block=True) @@ -560,7 +563,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: for i, stmt in enumerate(node.body): # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - head_state._debuginfo = dace_utils.debug_info(stmt, default=sdfg.debuginfo) + head_state._debuginfo = gtir_sdfg_utils.debug_info(stmt, default=sdfg.debuginfo) head_state = self.visit(stmt, sdfg=sdfg, state=head_state) # remove unused connectivity tables (by design, arrays are marked as non-transient when they are used) @@ -568,7 +571,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: unused_connectivities = [ data for data, datadesc in nsdfg.arrays.items() - if dace_utils.is_connectivity_identifier(data, self.offset_provider_type) + if gtx_dace_utils.is_connectivity_identifier(data, self.offset_provider_type) and datadesc.transient ] for data in unused_connectivities: @@ -757,7 +760,7 @@ def get_field_domain_offset( elif field_domain_offset := self.field_offsets.get(p_name, None): return {p_name: field_domain_offset} elif isinstance(p_type, ts.TupleType): - tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) + tsyms = gtir_sdfg_utils.flatten_tuple_fields(p_name, p_type) return functools.reduce( lambda field_offsets, sym: ( field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] @@ -774,7 +777,7 @@ def get_field_domain_offset( # lower let-statement lambda node as a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo) lambda_translator = self.setup_nested_context( node.expr, nsdfg, lambda_symbols, lambda_field_offsets ) @@ -792,8 +795,8 @@ def get_field_domain_offset( # we they are stored as non-transient array and scalar objects. # connectivity_arrays = { - dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) + gtx_dace_utils.connectivity_identifier(offset) + for offset in gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) } input_memlets = {} @@ -858,7 +861,7 @@ def get_field_domain_offset( inputs=set(input_memlets.keys()), outputs=lambda_outputs, symbol_mapping=nsdfg_symbols_mapping, - debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo), + debuginfo=gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo), ) for connector, memlet in input_memlets.items(): @@ -965,7 +968,7 @@ def build_sdfg_from_gtir( # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name # separator in the SSA pass, which generates invalid symbols for DaCe. # Here we find new names for invalid symbols present in the IR. - ir = dace_gtir_utils.replace_invalid_symbols(ir) + ir = gtir_sdfg_utils.replace_invalid_symbols(ir) sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis) sdfg = sdfg_genenerator.visit(ir) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py similarity index 86% rename from src/gt4py/next/program_processors/runners/dace_fieldview/utility.py rename to src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py index 6121529161..9a27cad21c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Dict, TypeVar +from typing import Dict, Optional, TypeVar import dace @@ -19,6 +19,22 @@ from gt4py.next.type_system import type_specifications as ts +def debug_info( + node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + """Include the GT4Py node location as debug information in the corresponding SDFG nodes.""" + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return default + + def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace/program.py similarity index 91% rename from src/gt4py/next/program_processors/runners/dace_fieldview/program.py rename to src/gt4py/next/program_processors/runners/dace/program.py index 7f809152c5..78016db0a9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -20,7 +20,7 @@ from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import extractors as extractors from gt4py.next.otf import arguments, recipes, toolchain -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils from gt4py.next.type_system import type_specifications as ts @@ -152,15 +152,15 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ in_arrays_with_id = [ (name, conn_id) for name in with_table - if (conn_id := dace_utils.connectivity_identifier(name)) + if (conn_id := gtx_dace_utils.connectivity_identifier(name)) in self.sdfg_closure_cache["arrays"] ] in_arrays = (name for name, _ in in_arrays_with_id) name_axis = list(itertools.product(in_arrays, [0, 1])) def size_symbol_name(name: str, axis: int) -> str: - return dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(name), axis + return gtx_dace_utils.field_size_symbol_name( + gtx_dace_utils.connectivity_identifier(name), axis ) connectivity_tables_size_symbols = { @@ -169,8 +169,8 @@ def size_symbol_name(name: str, axis: int) -> str: } def stride_symbol_name(name: str, axis: int) -> str: - return dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(name), axis + return gtx_dace_utils.field_stride_symbol_name( + gtx_dace_utils.connectivity_identifier(name), axis ) connectivity_table_stride_symbols = { @@ -196,12 +196,12 @@ def stride_symbol_name(name: str, axis: int) -> str: self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type), shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], + symbols[gtx_dace_utils.field_size_symbol_name(conn_id, 0)], + symbols[gtx_dace_utils.field_size_symbol_name(conn_id, 1)], ], strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], + symbols[gtx_dace_utils.field_stride_symbol_name(conn_id, 0)], + symbols[gtx_dace_utils.field_stride_symbol_name(conn_id, 1)], ], storage=Program.connectivity_tables_data_descriptors["storage"], ) @@ -221,7 +221,7 @@ def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: li ): match dace_parsed_arg: case dace.data.Scalar(): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) + assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg) case bool() | np.bool_(): assert isinstance(gt4py_program_arg, ts.ScalarType) assert gt4py_program_arg.kind == ts.ScalarKind.BOOL @@ -238,7 +238,7 @@ def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: li assert isinstance(gt4py_program_arg, ts.FieldType) assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) + assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg.dtype) case dace.data.Structure() | dict() | collections.OrderedDict(): # offset provider pass diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_common/dace_backend.py rename to src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index 387619c667..ecd23619e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common -from . import utility as dace_utils +from . import utils as gtx_dace_utils try: @@ -116,7 +116,7 @@ def get_sdfg_conn_args( connectivity_args = {} for offset, connectivity in offset_provider.items(): if gtx_common.is_neighbor_table(connectivity): - param = dace_utils.connectivity_identifier(offset) + param = gtx_dace_utils.connectivity_identifier(offset) if param in sdfg.arrays: connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py similarity index 96% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py rename to src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 0902bd665a..c8e1cf292f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -42,7 +42,7 @@ gt_propagate_strides_from_access_node, gt_propagate_strides_of, ) -from .util import gt_find_constant_arguments, gt_make_transients_persistent +from .utils import gt_find_constant_arguments, gt_make_transients_persistent __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py rename to src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 4a06d2f416..849730db76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -15,9 +15,7 @@ from dace.transformation.auto import auto_optimize as dace_aoptimize from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations def gt_auto_optimize( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py rename to src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 7b14144ead..8bae56cd88 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -22,9 +22,7 @@ from dace.codegen.targets import cpp as dace_cpp from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations def gt_gpu_transformation( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py rename to src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py index 52f1de3d0c..02ecbe28e6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py @@ -18,9 +18,7 @@ ) from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations def gt_create_local_double_buffering( @@ -243,7 +241,7 @@ def _check_if_map_must_be_handled_classify_adjacent_access_node( # Currently we do not handle view, as they need to be traced. # TODO(phimuell): Implement - if gtx_transformations.util.is_view(data_desc, sdfg): + if gtx_transformations.utils.is_view(data_desc, sdfg): return False # TODO(phimuell): Check if there is a access node on the inner side, then we do not have to do it. @@ -357,7 +355,7 @@ def _check_if_map_must_be_handled( if ( len(inner_read_edges) == 1 and isinstance(inner_read_edges[0].dst, dace_nodes.AccessNode) - and not gtx_transformations.util.is_view(inner_read_edges[0].dst, sdfg) + and not gtx_transformations.utils.is_view(inner_read_edges[0].dst, sdfg) ): inout_datas.pop(inout_data_name) continue diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py rename to src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 27b6c68072..344b0b8c22 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -19,7 +19,7 @@ from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +from gt4py.next.program_processors.runners.dace import gtir_sdfg_utils @dace_properties.make_properties @@ -83,7 +83,7 @@ def __init__( ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): - blocking_parameter = gtx_dace_fieldview_util.get_map_variable(blocking_parameter) + blocking_parameter = gtir_sdfg_utils.get_map_variable(blocking_parameter) if blocking_parameter is not None: self.blocking_parameter = blocking_parameter if blocking_size is not None: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py similarity index 100% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py similarity index 97% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py index 8fb41c7d0a..23dcbf8ef7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py @@ -13,7 +13,7 @@ from dace.sdfg import nodes as dace_nodes from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +from gt4py.next.program_processors.runners.dace import gtir_sdfg_utils def gt_set_iteration_order( @@ -107,7 +107,7 @@ def __init__( self.leading_dims = [ leading_dim if isinstance(leading_dim, str) - else gtx_dace_fieldview_util.get_map_variable(leading_dim) + else gtir_sdfg_utils.get_map_variable(leading_dim) for leading_dim in leading_dims ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py index 46d46c4bbe..90ad67e7cb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py @@ -17,9 +17,7 @@ ) from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @dace_properties.make_properties diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py rename to src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index cc845505c9..e798df4596 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -27,9 +27,7 @@ passes as dace_passes, ) -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} @@ -336,7 +334,7 @@ def _is_read_downstream( write_g: dace_nodes.AccessNode = self.node_write_g tmp_node: dace_nodes.AccessNode = self.node_tmp - return gtx_transformations.util.is_accessed_downstream( + return gtx_transformations.utils.is_accessed_downstream( start_state=start_state, sdfg=sdfg, data_to_look=data_to_look, @@ -591,7 +589,7 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: for def_node, def_state in def_locations: # Test if `temp_storage` is only accessed where it is defined and # where it is written back. - if gtx_transformations.util.is_accessed_downstream( + if gtx_transformations.utils.is_accessed_downstream( start_state=def_state, sdfg=sdfg, data_to_look=wb_node.data, @@ -607,7 +605,7 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: global_nodes_in_def_state = { dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name } - if gtx_transformations.util.is_accessed_downstream( + if gtx_transformations.utils.is_accessed_downstream( start_state=def_state, sdfg=sdfg, data_to_look=global_data_name, @@ -891,7 +889,7 @@ def apply( # The data is no longer referenced in this state, so we can potentially # remove if graph.out_degree(access_node) == 0: - if not gtx_transformations.util.is_accessed_downstream( + if not gtx_transformations.utils.is_accessed_downstream( start_state=graph, sdfg=sdfg, data_to_look=access_node.data, @@ -1001,7 +999,7 @@ def can_be_applied( return False if graph.in_degree(tmp_ac) != 1: return False - if any(gtx_transformations.util.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): + if any(gtx_transformations.utils.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): return False if len(glob_desc.shape) != len(tmp_desc.shape): return False @@ -1017,7 +1015,7 @@ def can_be_applied( # Test if `tmp` is only anywhere else, this is important for removing it. if graph.out_degree(tmp_ac) != 1: return False - if gtx_transformations.util.is_accessed_downstream( + if gtx_transformations.utils.is_accessed_downstream( start_state=graph, sdfg=sdfg, data_to_look=tmp_ac.data, @@ -1067,7 +1065,7 @@ def _perform_pointwise_test( # Find the source of this data, if it is a view we trace it to # its origin. - src_node: dace_nodes.AccessNode = gtx_transformations.util.track_view( + src_node: dace_nodes.AccessNode = gtx_transformations.utils.track_view( in_edge.src, state, sdfg ) @@ -1089,7 +1087,7 @@ def _perform_pointwise_test( # Currently the only test that we do is, if we have a view, then we # are not point wise. # TODO(phimuell): Improve/implement this. - return any(gtx_transformations.util.is_view(node, sdfg) for node in conflicting_inputs) + return any(gtx_transformations.utils.is_view(node, sdfg) for node in conflicting_inputs) def apply( self, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py similarity index 99% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py rename to src/gt4py/next/program_processors/runners/dace/transformations/strides.py index d1bf8fe266..9af76e5b57 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -12,9 +12,7 @@ from dace import data as dace_data from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] @@ -653,7 +651,7 @@ def _gt_find_toplevel_data_accesses( top_level_data[data].append((state, dnode)) continue - elif gtx_transformations.util.is_view(dnode, sdfg): + elif gtx_transformations.utils.is_view(dnode, sdfg): # The AccessNode refers to a View so we ignore it anyway. continue diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py rename to src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 29c099eecf..87308061e7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -14,7 +14,7 @@ from dace import data as dace_data from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils def gt_make_transients_persistent( @@ -114,7 +114,7 @@ def gt_find_constant_arguments( ret_value: dict[str, Any] = {} for name, value in call_args.items(): - if name in include or (dace_utils.is_field_symbol(name) and value == 1): + if name in include or (gtx_dace_utils.is_field_symbol(name) and value == 1): ret_value[name] = value return ret_value diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace/utils.py similarity index 79% rename from src/gt4py/next/program_processors/runners/dace_common/utility.py rename to src/gt4py/next/program_processors/runners/dace/utils.py index 4a3e5d4e4c..cca0c001e7 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace/utils.py @@ -9,12 +9,11 @@ from __future__ import annotations import re -from typing import Final, Literal, Optional +from typing import Final, Literal import dace from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir from gt4py.next.type_system import type_specifications as ts @@ -78,22 +77,6 @@ def is_field_symbol(name: str) -> bool: return FIELD_SYMBOL_RE.match(name) is not None -def debug_info( - node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None -) -> Optional[dace.dtypes.DebugInfo]: - """Include the GT4Py node location as debug information in the corresponding SDFG nodes.""" - location = node.location - if location: - return dace.dtypes.DebugInfo( - start_line=location.line, - start_column=location.column if location.column else 0, - end_line=location.end_line if location.end_line else -1, - end_column=location.end_column if location.end_column else 0, - filename=location.filename, - ) - return default - - def filter_connectivity_types( offset_provider_type: gtx_common.OffsetProviderType, ) -> dict[str, gtx_common.NeighborConnectivityType]: diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py new file mode 100644 index 0000000000..4d825c0c9b --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -0,0 +1,20 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the On-The-Fly (OTF) compilation workflow for the GTIR-DaCe backend. + +The main module is `backend`, that exports the backends for CPU and GPU devices. +The `backend` module uses `factory` to define a workflow that implements the +`OTFCompileWorkflow` recipe. The different stages are implemeted in separate modules: +- `translation` for lowering of GTIR to SDFG and applying SDFG transformations +- `compilation` for compiling the SDFG into a program +- `decoration` to parse the program arguments and pass them to the program call + +The GTIR-DaCe backend factory extends `CachedBackendFactory`, thus it provides +caching of the GTIR program. +""" diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py similarity index 72% rename from src/gt4py/next/program_processors/runners/dace.py rename to src/gt4py/next/program_processors/runners/dace/workflow/backend.py index b7f419a749..55d7122767 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -6,16 +6,18 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import factory -import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators +from gt4py._core import definitions as core_defs from gt4py.next import backend from gt4py.next.otf import stages, workflow -from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow +from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory -class DaCeFieldviewBackendFactory(factory.Factory): +class DaCeBackendFactory(factory.Factory): class Meta: model = backend.Backend @@ -37,7 +39,7 @@ class Params: device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( - dace_fieldview_workflow.DaCeWorkflowFactory, + DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), auto_optimize=factory.SelfAttribute("..auto_optimize"), ) @@ -52,8 +54,8 @@ class Params: transforms = backend.DEFAULT_TRANSFORMS -run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) +run_dace_cpu = DaCeBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeBackendFactory(cached=True, auto_optimize=False) -run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) +run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py similarity index 50% rename from src/gt4py/next/program_processors/runners/dace_common/workflow.py rename to src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index 6fb7539c92..c0d1c74c7a 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -8,19 +8,16 @@ from __future__ import annotations -import ctypes import dataclasses -from typing import Any, Sequence +from typing import Any import dace import factory -from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr from gt4py._core import definitions as core_defs -from gt4py.next import common, config, utils as gtx_utils -from gt4py.next.otf import arguments, languages, stages, step_types, workflow +from gt4py.next import config +from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils class CompiledDaceProgram(stages.ExtendedCompiledProgram): @@ -95,75 +92,3 @@ def __call__( class DaCeCompilationStepFactory(factory.Factory): class Meta: model = DaCeCompiler - - -def convert_args( - inp: CompiledDaceProgram, - device: core_defs.DeviceType = core_defs.DeviceType.CPU, - use_field_canonical_representation: bool = False, -) -> stages.CompiledProgram: - sdfg_program = inp.sdfg_program - sdfg = sdfg_program.sdfg - on_gpu = True if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] else False - - def decorated_program( - *args: Any, - offset_provider: common.OffsetProvider, - out: Any = None, - ) -> None: - if out is not None: - args = (*args, out) - flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) - if inp.implicit_domain: - # generate implicit domain size arguments only if necessary - size_args = arguments.iter_size_args(args) - flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args)) - flat_args = (*flat_args, *flat_size_args) - - if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) - kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) - - use_fast_call = True - last_call_args = sdfg_program._lastargs[0] - # The scalar arguments should be overridden with the new value; for field arguments, - # the data pointer should remain the same otherwise fast_call cannot be used and - # the arguments list has to be reconstructed. - for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): - if isinstance(arg_type, dace.data.Array): - assert arg_name in kwargs, f"argument '{arg_name}' not found." - data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) - assert isinstance(last_call_args[i], ctypes.c_void_p) - if last_call_args[i].value != data_ptr: - use_fast_call = False - break - else: - assert isinstance(arg_type, dace.data.Scalar) - assert isinstance(last_call_args[i], ctypes._SimpleCData) - if arg_name in kwargs: - # override the scalar value used in previous program call - actype = arg_type.dtype.as_ctypes() - last_call_args[i] = actype(kwargs[arg_name]) - else: - # shape and strides of arrays are supposed not to change, and can therefore be omitted - assert dace_utils.is_field_symbol( - arg_name - ), f"argument '{arg_name}' not found." - - if use_fast_call: - return inp.fast_call() - - sdfg_args = dace_backend.get_sdfg_args( - sdfg, - offset_provider, - *flat_args, - check_args=False, - on_gpu=on_gpu, - use_field_canonical_representation=use_field_canonical_representation, - ) - - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - return inp(**sdfg_args) - - return decorated_program diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py new file mode 100644 index 0000000000..2ee99f5fa4 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -0,0 +1,96 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ctypes +from typing import Any, Sequence + +import dace +from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr + +from gt4py._core import definitions as core_defs +from gt4py.next import common, utils as gtx_utils +from gt4py.next.otf import arguments, stages +from gt4py.next.program_processors.runners.dace import ( + sdfg_callable, + utils as gtx_dace_utils, + workflow as dace_worflow, +) + + +def convert_args( + inp: dace_worflow.compilation.CompiledDaceProgram, + device: core_defs.DeviceType = core_defs.DeviceType.CPU, + use_field_canonical_representation: bool = False, +) -> stages.CompiledProgram: + sdfg_program = inp.sdfg_program + sdfg = sdfg_program.sdfg + on_gpu = True if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] else False + + def decorated_program( + *args: Any, + offset_provider: common.OffsetProvider, + out: Any = None, + ) -> None: + if out is not None: + args = (*args, out) + flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) + if inp.implicit_domain: + # generate implicit domain size arguments only if necessary + size_args = arguments.iter_size_args(args) + flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args)) + flat_args = (*flat_args, *flat_size_args) + + if sdfg_program._lastargs: + kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) + kwargs.update(sdfg_callable.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) + + use_fast_call = True + last_call_args = sdfg_program._lastargs[0] + # The scalar arguments should be overridden with the new value; for field arguments, + # the data pointer should remain the same otherwise fast_call cannot be used and + # the arguments list has to be reconstructed. + for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): + if isinstance(arg_type, dace.data.Array): + assert arg_name in kwargs, f"argument '{arg_name}' not found." + data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) + assert isinstance(last_call_args[i], ctypes.c_void_p) + if last_call_args[i].value != data_ptr: + use_fast_call = False + break + else: + assert isinstance(arg_type, dace.data.Scalar) + assert isinstance(last_call_args[i], ctypes._SimpleCData) + if arg_name in kwargs: + # override the scalar value used in previous program call + actype = arg_type.dtype.as_ctypes() + last_call_args[i] = actype(kwargs[arg_name]) + else: + # shape and strides of arrays are supposed not to change, and can therefore be omitted + assert gtx_dace_utils.is_field_symbol( + arg_name + ), f"argument '{arg_name}' not found." + + if use_fast_call: + return inp.fast_call() + + sdfg_args = sdfg_callable.get_sdfg_args( + sdfg, + offset_provider, + *flat_args, + check_args=False, + on_gpu=on_gpu, + use_field_canonical_representation=use_field_canonical_representation, + ) + + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + return inp(**sdfg_args) + + return decorated_program diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py new file mode 100644 index 0000000000..02a089c88c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -0,0 +1,58 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import functools + +import factory + +from gt4py._core import definitions as core_defs +from gt4py.next import config +from gt4py.next.otf import recipes, stages +from gt4py.next.program_processors.runners.dace.workflow import decoration as decoration_step +from gt4py.next.program_processors.runners.dace.workflow.compilation import ( + DaCeCompilationStepFactory, +) +from gt4py.next.program_processors.runners.dace.workflow.translation import ( + DaCeTranslationStepFactory, +) + + +def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: + return stages.CompilableSource(program_source=inp, binding_source=None) + + +class DaCeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( + lambda: config.CMAKE_BUILD_TYPE + ) + auto_optimize: bool = False + + translation = factory.SubFactory( + DaCeTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + ) + bindings = _no_bindings + compilation = factory.SubFactory( + DaCeCompilationStepFactory, + cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) + decoration = factory.LazyAttribute( + lambda o: functools.partial( + decoration_step.convert_args, + device=o.device_type, + ) + ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py similarity index 71% rename from src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py rename to src/gt4py/next/program_processors/runners/dace/workflow/translation.py index a83654ebc9..96be93de5e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -9,20 +9,18 @@ from __future__ import annotations import dataclasses -import functools from typing import Optional import dace import factory from gt4py._core import definitions as core_defs -from gt4py.next import allocators as gtx_allocators, common, config +from gt4py.next import allocators as gtx_allocators, common from gt4py.next.iterator import ir as itir, transforms as itir_transforms -from gt4py.next.otf import languages, recipes, stages, step_types, workflow +from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( gtir_sdfg, transformations as gtx_transformations, ) @@ -106,37 +104,3 @@ def __call__( class DaCeTranslationStepFactory(factory.Factory): class Meta: model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - auto_optimize: bool = False - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/__init__.py b/src/gt4py/next/program_processors/runners/dace_common/__init__.py deleted file mode 100644 index abf4c3e24c..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_common/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py deleted file mode 100644 index 602453fc5a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - - -from gt4py.next.program_processors.runners.dace_common.dace_backend import get_sdfg_args -from gt4py.next.program_processors.runners.dace_fieldview.gtir_sdfg import build_sdfg_from_gtir - - -__all__ = [ - "build_sdfg_from_gtir", - "get_sdfg_args", -] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 225d22562f..faf611878d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -37,7 +37,7 @@ from . import pytestmark -dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview") +dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace") N = 10 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py index 04a4f098ef..8177ea9ae7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -11,7 +11,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py index 3d9201c603..88786ee0e3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -14,7 +14,7 @@ from dace.sdfg import nodes as dace_nodes from dace import data as dace_data -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index 709079dd0d..9241bae4bf 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -9,7 +9,7 @@ import pytest import numpy as np -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py index 4ca44d43eb..1d98fef8c4 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py @@ -11,7 +11,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 89f067e5a9..cdc66d4ffd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -14,7 +14,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview.transformations import ( +from gt4py.next.program_processors.runners.dace.transformations import ( gpu_utils as gtx_dace_fieldview_gpu_utils, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index 67bec9c09f..a08cf12a5a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -16,7 +16,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index a98eac3c2c..f2c31a7188 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -13,7 +13,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index b468b80b8e..516a70b579 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -17,7 +17,7 @@ from dace.sdfg import nodes as dace_nodes from dace.transformation import dataflow as dace_dataflow -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py index 72efc2fe34..762040e20d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -9,7 +9,7 @@ import pytest import numpy as np -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py index 7b39bc4e1d..7718977d53 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -14,7 +14,7 @@ from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation from dace.transformation import dataflow as dace_dataflow -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index 8626cb8e07..fa7c7255e3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -12,7 +12,7 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 19b33d0bef..c89fe566c0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -14,7 +14,7 @@ from dace import symbolic as dace_symbolic from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) From 758ee037da2c15ce89510e9314eaa47507c9452a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 22 Jan 2025 10:02:35 +0100 Subject: [PATCH 109/178] test[cartesian]: Increased coverage for horizontal regions (#1807) Working on the gt4py/dace bridge showed that code coverage for horizontal regions is low. In particular (code generation) tests for conditionals inside horizontal regions were missing. --- .../stencil_definitions.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 1a8cfef695..e1d9a0061a 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -13,6 +13,8 @@ __INLINED, BACKWARD, FORWARD, + I, + J, PARALLEL, acos, acosh, @@ -28,6 +30,7 @@ exp, floor, gamma, + horizontal, interval, isfinite, isinf, @@ -35,6 +38,7 @@ log, log10, mod, + region, sin, sinh, sqrt, @@ -402,3 +406,23 @@ def two_optional_fields( out_a = out_a + dt * phys_tend_a if __INLINED(PHYS_TEND_B): out_b = out_b + dt * phys_tend_b + + +@register +def horizontal_regions(field_in: Field3D, field_out: Field3D): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + field_out = field_in + 1.0 + + with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]): + field_out = field_in - 1.0 + + +@register +def horizontal_region_with_conditional(field_in: Field3D, field_out: Field3D): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + if field_in > 0: + field_out = field_in + 1.0 + else: + field_out = 0 From 57ad892e4a73957edba6df1b0fcc041dc4fd97eb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 22 Jan 2025 15:54:56 +0100 Subject: [PATCH 110/178] style[cartesian]: fixing typos (#1815) ## Description Fixing typos (in comments) in the cartesian part of the codebase. No code was touched. ## Requirements - [ ] All fixes and/or new features come with corresponding tests: N/A - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder: N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/backend/module_generator.py | 4 +--- src/gt4py/cartesian/caching.py | 2 +- src/gt4py/cartesian/frontend/nodes.py | 1 - src/gt4py/cartesian/gtc/common.py | 2 +- src/gt4py/cartesian/gtc/cuir/cuir_codegen.py | 2 +- src/gt4py/cartesian/gtc/dace/expansion/expansion.py | 1 - src/gt4py/cartesian/gtc/definitions.py | 4 ++-- src/gt4py/cartesian/gtc/gtir_to_oir.py | 1 - src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py | 3 +-- src/gt4py/cartesian/gtscript_imports.py | 3 +-- src/gt4py/cartesian/testing/suites.py | 2 +- src/gt4py/eve/__init__.py | 1 - src/gt4py/eve/datamodels/__init__.py | 1 - 13 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/gt4py/cartesian/backend/module_generator.py b/src/gt4py/cartesian/backend/module_generator.py index e2266b709c..8cc63ae34e 100644 --- a/src/gt4py/cartesian/backend/module_generator.py +++ b/src/gt4py/cartesian/backend/module_generator.py @@ -62,8 +62,6 @@ def parameter_names(self) -> Set[str]: def make_args_data_from_gtir(pipeline: GtirPipeline) -> ModuleData: """ Compute module data containing information about stencil arguments from gtir. - - This is no longer compatible with the legacy backends. """ if pipeline.stencil_id in _args_data_cache: return _args_data_cache[pipeline.stencil_id] @@ -142,7 +140,7 @@ def __call__( """ Generate source code for a Python module containing a StencilObject. - A possible reaosn for extending is processing additional kwargs, + A possible reason for extending is processing additional kwargs, using a different template might require completely overriding. """ if builder: diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index 20c0b49fae..2df2589ded 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -61,7 +61,7 @@ def generate_cache_info(self) -> Dict[str, Any]: """ Generate the cache info dict. - Backend specific additions can be added via a hook propery on the backend instance. + Backend specific additions can be added via a hook properly on the backend instance. Override :py:meth:`gt4py.backend.base.Backend.extra_cache_info` to store extra info. """ diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index f84577e7b5..2ca9e8fe1f 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -130,7 +130,6 @@ parameters: List[VarDecl], computations: List[ComputationBlock], [externals: Dict[str, Any], sources: Dict[str, str]]) - """ from __future__ import annotations diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 8c3c731c75..7b2fbc93d8 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -60,7 +60,7 @@ class AssignmentKind(eve.StrEnum): @enum.unique class UnaryOperator(eve.StrEnum): - """Unary operator indentifier.""" + """Unary operator identifier.""" POS = "+" NEG = "-" diff --git a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py index 76f076874a..96149a1723 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py @@ -592,7 +592,7 @@ def ctype(symbol: str) -> str: @classmethod def apply(cls, root: LeafNode, **kwargs: Any) -> str: if not isinstance(root, cuir.Program): - raise ValueError("apply() requires gtcpp.Progam root node") + raise ValueError("apply() requires gtcpp.Program root node") generated_code = super().apply(root, **kwargs) if kwargs.get("format_source", True): generated_code = codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 055bf64015..27f55d451d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -74,7 +74,6 @@ def _fix_context( * change connector names to match inner array name (before expansion prefixed to satisfy uniqueness) * change in- and out-edges' subsets so that they have the same shape as the corresponding array inside * determine the domain size based on edges to StencilComputation - """ # change connector names for in_edge in parent_state.in_edges(node): diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index 16c7fbc46a..467ba04f99 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -118,7 +118,7 @@ def __add__(self, other): return self._apply(self._broadcast(other), operator.add) def __sub__(self, other): - """Element-wise substraction.""" + """Element-wise subtraction.""" return self._apply(self._broadcast(other), operator.sub) def __mul__(self, other): @@ -335,7 +335,7 @@ def __add__(self, other): return self._apply(self._broadcast(other), lambda a, b: a + b) def __sub__(self, other): - """Element-wise substraction.""" + """Element-wise subtraction.""" return self._apply(self._broadcast(other), lambda a, b: a - b) def __and__(self, other): diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index d36c2e5c4a..96f8077ec4 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -22,7 +22,6 @@ def validate_stencil_memory_accesses(node: oir.Stencil) -> oir.Stencil: at the OIR level. This is similar to the check at the gtir level for read-with-offset and writes, but more complete because it involves extent analysis, so it catches indirect read-with-offset through temporaries. - """ def _writes(node: oir.Stencil) -> Set[str]: diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py index f6c864aaba..fd09017720 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py @@ -39,7 +39,6 @@ Note that filling and flushing k-caches can always be replaced by a local (non-filling or flushing) k-cache plus additional filling and flushing statements. - """ @@ -261,7 +260,7 @@ class FillFlushToLocalKCaches(eve.NodeTranslator, eve.VisitorWithSymbolTableTrai For each cached field, the following actions are performed: 1. A new locally-k-cached temporary is introduced. 2. All accesses to the original field are replaced by accesses to this temporary. - 3. Loop sections are split where necessary to allow single-level loads whereever possible. + 3. Loop sections are split where necessary to allow single-level loads wherever possible. 3. Fill statements from the original field to the temporary are introduced. 4. Flush statements from the temporary to the original field are introduced. """ diff --git a/src/gt4py/cartesian/gtscript_imports.py b/src/gt4py/cartesian/gtscript_imports.py index 109f19759e..6fe49f18dd 100644 --- a/src/gt4py/cartesian/gtscript_imports.py +++ b/src/gt4py/cartesian/gtscript_imports.py @@ -23,13 +23,12 @@ gtscript_imports.enable( search_path=[, , ...], # for allowing only in search_path generate_path=, # for generating python modules in a specific dir - in_source=False, # set True to generate python modules next to gtscfipt files + in_source=False, # set True to generate python modules next to gtscript files ) # scoped usage with gtscript_imports.enabled(): import ... - """ import importlib diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 48bead86e2..f680a1dbef 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -534,7 +534,7 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl # call implementation implementation(**test_values, origin=origin, domain=domain, exec_info=exec_info) - # for validation data, data is cropped to actually touched domain, so that origin offseting + # for validation data, data is cropped to actually touched domain, so that origin offsetting # does not have to be implemented for every test suite. This is done based on info # specified in test suite cropped_validation_values = {} diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index e6044f15ef..e294108011 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -21,7 +21,6 @@ 7. visitors 8. traits 9. codegen - """ from __future__ import annotations diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 6fd9c7bb21..5f6806c5dd 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -104,7 +104,6 @@ >>> CustomModel(3, 2) Instance 1 == 1.5 CustomModel(value=1.5) - """ from . import core as core, validators as validators # imported but unused From 2866f38bc6c93b527773661467acc55f4485ec0c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 23 Jan 2025 00:51:59 +0100 Subject: [PATCH 111/178] feat[next]: Only inline scalars outside of stencils (#1794) The `InlineScalar` pass previously inlined all scalar expressions which is not needed and unnecessarily increases the tree size. Instead we only inline them in field view now. Co-authored-by: Hannes Vogt --- .../next/iterator/transforms/inline_scalar.py | 6 +++ .../transforms_tests/test_inline_scalar.py | 47 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index 87b576d14d..d8a6e14d8a 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -21,6 +21,12 @@ def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProvide program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) + def generic_visit(self, node, **kwargs): + if cpm.is_call_to(node, "as_fieldop"): + return node + + return super().generic_visit(node, **kwargs) + def visit_Expr(self, node: itir.Expr): node = self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py new file mode 100644 index 0000000000..3e655b71f4 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import pytest + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms import inline_scalar +from gt4py.next.iterator.ir_utils import ir_makers as im + +TDim = common.Dimension(value="TDim") +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + + +def program_factory(expr: itir.Expr) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[im.sym("out", ts.FieldType(dims=[TDim], dtype=int_type))], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + target=im.ref("out"), + domain=im.domain(common.GridType.CARTESIAN, {TDim: (0, 1)}), + ) + ], + ) + + +def test_simple(): + testee = program_factory(im.let("a", 1)(im.op_as_fieldop("plus")("a", "a"))) + expected = program_factory(im.op_as_fieldop("plus")(1, 1)) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == expected + + +def test_fo_inline_only(): + scalar_expr = im.let("a", 1)(im.plus("a", "a")) + testee = program_factory(im.as_fieldop(im.lambda_()(scalar_expr))()) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == testee From 12f8ce86c713c78cbd4785eacccd6ef9e986bc6e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 24 Jan 2025 09:25:18 +0100 Subject: [PATCH 112/178] build: bump actions/upload-artifacts v3 -> v4 (#1821) `actions/upload-artifacts@v3` is deprecated and GitHub won't support it starting January 30th. Starting today, jobs using v3 were rejected, blocking merges. See [GithHub announcement](https://github.blog/changelog/2024-04-16-deprecation-notice-v3-of-the-artifact-actions/). Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .github/workflows/deploy-release.yml | 2 +- .github/workflows/test-eve.yml | 4 ++-- .github/workflows/test-next.yml | 4 ++-- .github/workflows/test-storage.yml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/deploy-release.yml b/.github/workflows/deploy-release.yml index 9ce6983de1..b519e008ec 100644 --- a/.github/workflows/deploy-release.yml +++ b/.github/workflows/deploy-release.yml @@ -26,7 +26,7 @@ jobs: run: | python -m build --sdist --wheel --outdir dist/ - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: gt4py-dist path: ./dist/** diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 6b9f16e29b..9d48d50c03 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -50,7 +50,7 @@ jobs: tox run -e eve-py${pyversion_no_dot} # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json @@ -64,7 +64,7 @@ jobs: # echo ${{ github.event.pull_request.head.sha }} >> info.txt # echo ${{ github.run_id }} >> info.txt # - name: Upload info artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 35dcfe336b..1928370202 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -60,7 +60,7 @@ jobs: tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json @@ -74,7 +74,7 @@ jobs: # echo ${{ github.event.pull_request.head.sha }} >> info.txt # echo ${{ github.run_id }} >> info.txt # - name: Upload info artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu # path: info.txt diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index a7f3b69c8d..3748ac193e 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -52,7 +52,7 @@ jobs: tox run -e storage-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json @@ -66,7 +66,7 @@ jobs: # echo ${{ github.event.pull_request.head.sha }} >> info.txt # echo ${{ github.run_id }} >> info.txt # - name: Upload info artifact - # uses: actions/upload-artifact@v3 + # uses: actions/upload-artifact@v4 # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt From 5588c85e2ae642ffe5122005f029a4d2b7405c63 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 27 Jan 2025 09:41:43 +0100 Subject: [PATCH 113/178] ci[cartesian]: mypy warns about unused ignores (#1823) ## Description `mypy` was configured not to report unused ignores in `gt4py.cartesian`. I could remove all unused ignores without any problem and thus removed this extra configuration. From now on, unused ignores will be reported. This is work towards https://github.com/GEOS-ESM/SMT-Nebulae/issues/89. ## Requirements - [x] All fixes and/or new features come with corresponding tests. `mypy` is happy with the change. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- pyproject.toml | 1 - src/gt4py/cartesian/backend/cuda_backend.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 2 +- src/gt4py/cartesian/backend/gtcpp_backend.py | 2 +- src/gt4py/cartesian/gtc/common.py | 10 +++++----- src/gt4py/cartesian/gtc/cuir/cuir.py | 8 ++++---- src/gt4py/cartesian/gtc/dace/daceir.py | 2 +- .../cartesian/gtc/dace/expansion/daceir_builder.py | 3 +-- src/gt4py/cartesian/gtc/gtcpp/gtcpp.py | 8 ++++---- src/gt4py/cartesian/gtc/gtir.py | 8 ++++---- src/gt4py/cartesian/gtc/oir.py | 8 ++++---- src/gt4py/cartesian/gtc/passes/gtir_upcaster.py | 2 +- src/gt4py/cartesian/stencil_builder.py | 5 +---- src/gt4py/cartesian/stencil_object.py | 2 +- 14 files changed, 29 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a9f62b8ae7..979dfbbd02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,7 +172,6 @@ allow_incomplete_defs = true allow_untyped_defs = true follow_imports = 'silent' module = 'gt4py.cartesian.*' -warn_unused_ignores = false [[tool.mypy.overrides]] ignore_errors = true diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index f0238e309b..afa749e3f1 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -136,7 +136,7 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin): } languages = {"computation": "cuda", "bindings": ["python"]} storage_info = gt_storage.layout.CUDALayout - PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = CudaExtGenerator MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator GT_BACKEND_T = "gpu" diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index a6d28f5994..35265f0530 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -760,7 +760,7 @@ class DaCeCUDAPyExtModuleGenerator(DaCePyExtModuleGenerator, CUDAPyExtModuleGene class BaseDaceBackend(BaseGTBackend, CLIBackendMixin): GT_BACKEND_T = "dace" - PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = DaCeExtGenerator def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 5d3fd623d9..8053409195 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -126,7 +126,7 @@ def apply(cls, root, *, module_name="stencil", **kwargs) -> str: class GTBaseBackend(BaseGTBackend, CLIBackendMixin): options = BaseGTBackend.GT_BACKEND_OPTS - PYEXT_GENERATOR_CLASS = GTExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = GTExtGenerator def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]: return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index 7b2fbc93d8..ef38a9a658 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -38,14 +38,14 @@ class GTCPreconditionError(eve.exceptions.EveError, RuntimeError): message_template = "GTC pass precondition error: [{info}]" def __init__(self, *, expected: str, **kwargs: Any) -> None: - super().__init__(expected=expected, **kwargs) # type: ignore + super().__init__(expected=expected, **kwargs) class GTCPostconditionError(eve.exceptions.EveError, RuntimeError): message_template = "GTC pass postcondition error: [{info}]" def __init__(self, *, expected: str, **kwargs: Any) -> None: - super().__init__(expected=expected, **kwargs) # type: ignore + super().__init__(expected=expected, **kwargs) class AssignmentKind(eve.StrEnum): @@ -267,7 +267,7 @@ def verify_and_get_common_dtype( ) -> Optional[DataType]: assert len(exprs) > 0 if all(e.dtype is not DataType.AUTO for e in exprs): - dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None + dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None dtype = dtypes[0] if strict: if all(dt == dtype for dt in dtypes): @@ -908,7 +908,7 @@ def op_to_ufunc( @functools.lru_cache(maxsize=None) def typestr_to_data_type(typestr: str) -> DataType: if not isinstance(typestr, str) or len(typestr) < 3 or not typestr[2:].isnumeric(): - return DataType.INVALID # type: ignore + return DataType.INVALID table = { ("b", 1): DataType.BOOL, ("i", 1): DataType.INT8, @@ -919,4 +919,4 @@ def typestr_to_data_type(typestr: str) -> DataType: ("f", 8): DataType.FLOAT64, } key = (typestr[1], int(typestr[2:])) - return table.get(key, DataType.INVALID) # type: ignore + return table.get(key, DataType.INVALID) diff --git a/src/gt4py/cartesian/gtc/cuir/cuir.py b/src/gt4py/cartesian/gtc/cuir/cuir.py index 62c3c520ac..fb6d28d071 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir.py @@ -32,11 +32,11 @@ class Stmt(common.Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass @@ -44,7 +44,7 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -113,7 +113,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 78451c30f5..492a9598c5 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -771,7 +771,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 952bafd46a..e93a15debe 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -420,8 +420,7 @@ def visit_HorizontalExecution( k_interval, **kwargs: Any, ): - # skip type checking due to https://github.com/python/mypy/issues/5485 - extent = global_ctx.library_node.get_extents(node) # type: ignore + extent = global_ctx.library_node.get_extents(node) decls = [self.visit(decl, **kwargs) for decl in node.declarations] targets: Set[str] = set() stmts = [ diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py index 0d19814b9c..5ca766c272 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py @@ -31,11 +31,11 @@ class Offset(common.CartesianOffset): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass -class LocalAccess(common.ScalarAccess, Expr): # type: ignore +class LocalAccess(common.ScalarAccess, Expr): pass @@ -43,7 +43,7 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -88,7 +88,7 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr): _dtype_propagation = common.native_func_call_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/gtir.py b/src/gt4py/cartesian/gtc/gtir.py index c9f58de2da..0ee4f7ebe1 100644 --- a/src/gt4py/cartesian/gtc/gtir.py +++ b/src/gt4py/cartesian/gtc/gtir.py @@ -43,7 +43,7 @@ class BlockStmt(common.BlockStmt[Stmt], Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass @@ -51,11 +51,11 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -163,7 +163,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=False) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index df71ef26cf..9f24db6e48 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -33,11 +33,11 @@ class Stmt(common.Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass @@ -45,7 +45,7 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -88,7 +88,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py index 41fa127d6d..94c3d6cd78 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py @@ -24,7 +24,7 @@ def _upcast_node(target_dtype: DataType, node: Expr) -> Expr: def _upcast_nodes(*exprs: Expr, upcasting_rule: Callable) -> Iterator[Expr]: assert all(e.dtype for e in exprs) - dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None + dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None target_dtypes = upcasting_rule(*dtypes) return iter(_upcast_node(target_dtype, arg) for target_dtype, arg in zip(target_dtypes, exprs)) diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index c0f58c0bc9..6ca2c673a1 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -58,10 +58,7 @@ def __init__( frontend: Optional[Type[FrontendType]] = None, ): self._definition = definition_func - # type ignore explanation: Attribclass generated init not recognized by mypy - self.options = options or BuildOptions( # type: ignore - **self.default_options_dict(definition_func) - ) + self.options = options or BuildOptions(**self.default_options_dict(definition_func)) backend = backend or "numpy" backend = gt4pyc.backend.from_name(backend) if isinstance(backend, str) else backend if backend is None: diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index b76415e17f..5e5976e3e5 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -513,7 +513,7 @@ def _normalize_origins( *((0,) * len(field_info.data_dims)), ) elif (info_origin := getattr(array_infos.get(name), "origin", None)) is not None: - origin[name] = info_origin # type: ignore + origin[name] = info_origin else: origin[name] = (0,) * field_info.ndim From 312b4a85628c6d10cfeaa214fa095172933a99ad Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 27 Jan 2025 17:55:01 +0100 Subject: [PATCH 114/178] build: bump actions/download-artifacts v3 -> v4 (#1822) --- .github/workflows/deploy-release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy-release.yml b/.github/workflows/deploy-release.yml index b519e008ec..7a7505caa5 100644 --- a/.github/workflows/deploy-release.yml +++ b/.github/workflows/deploy-release.yml @@ -42,7 +42,7 @@ jobs: id-token: write steps: - name: Download wheel - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: gt4py-dist path: dist @@ -60,7 +60,7 @@ jobs: id-token: write steps: - name: Download wheel - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: gt4py-dist path: dist From 12ff610cd520581b5bb06bae06666e7f4a9973be Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 27 Jan 2025 18:50:56 +0100 Subject: [PATCH 115/178] refactor[next]: fixed point transformation pass infrastructure (#1826) Executing transformations until a fixed point is reached, i.e., not transformation is applicable anymore is a common pattern encountered in many passes. The `CollapseTuple` pass already contained some infrastructure / utilities for this purpose. Since we now want to use the same approach in `ConstantFolding` and `FieldOpFusion`, they are extracted into a common base class here. This is essentially a step into the direction of a more general pass manager that allows composition of transformations. It is very hard to design something on the scratch board that fulfills all needs we have in the various transformations. The idea is to extend the capabilities of this class step-by-step to cover more and more transformation use-cases such that more passes can use this infrastructure until we can eventually evaluate how dissimilar passes can be composed, e.g. CollapseTuple and ConstantFolding. --------- Co-authored-by: Till Ehrengruber --- .../iterator/transforms/collapse_tuple.py | 74 ++++++------------- .../transforms/fixed_point_transformation.py | 67 +++++++++++++++++ .../next/iterator/transforms/pass_manager.py | 8 +- .../transforms_tests/test_collapse_tuple.py | 38 +++++----- 4 files changed, 114 insertions(+), 73 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/fixed_point_transformation.py diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0a0cf6d37e..6db58f3765 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -23,6 +23,7 @@ ir_makers as im, misc as ir_misc, ) +from gt4py.next.iterator.transforms import fixed_point_transformation from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -86,8 +87,10 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: # go through all available transformation and apply them. However the final result here still # reads a little convoluted and is also different to how we write other transformations. We # should revisit the pattern here and try to find a more general mechanism. -@dataclasses.dataclass(frozen=True) -class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): +@dataclasses.dataclass(frozen=True, kw_only=True) +class CollapseTuple( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -98,7 +101,7 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want # is something like a pass manager, where for each pattern we have a corresponding # transformation, etc. - class Flag(enum.Flag): + class Transformation(enum.Flag): #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto() #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` @@ -137,12 +140,12 @@ class Flag(enum.Flag): INLINE_TRIVIAL_LET = enum.auto() @classmethod - def all(self) -> CollapseTuple.Flag: + def all(self) -> CollapseTuple.Transformation: return functools.reduce(operator.or_, self.__members__.values()) uids: eve_utils.UIDGenerator ignore_tuple_size: bool - flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] + enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] PRESERVED_ANNEX_ATTRS = ("type",) @@ -155,8 +158,8 @@ def apply( remove_letified_make_tuple_elements: bool = True, offset_provider_type: Optional[common.OffsetProviderType] = None, within_stencil: Optional[bool] = None, - # manually passing flags is mostly for allowing separate testing of the modes - flags: Optional[Flag] = None, + # manually passing enabled transformations is mostly for allowing separate testing of the modes + enabled_transformations: Optional[Transformation] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, uids: Optional[eve_utils.UIDGenerator] = None, @@ -174,7 +177,7 @@ def apply( to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ - flags = flags or cls.flags + enabled_transformations = enabled_transformations or cls.enabled_transformations offset_provider_type = offset_provider_type or {} uids = uids or eve_utils.UIDGenerator() @@ -194,7 +197,7 @@ def apply( new_node = cls( ignore_tuple_size=ignore_tuple_size, - flags=flags, + enabled_transformations=enabled_transformations, uids=uids, ).visit(node, within_stencil=within_stencil) @@ -210,45 +213,17 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit(self, node, **kwargs): if cpm.is_call_to(node, "as_fieldop"): kwargs = {**kwargs, "within_stencil": True} - node = self.generic_visit(node, **kwargs) - return self.fp_transform(node, **kwargs) - - def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: - while True: - new_node = self.transform(node, **kwargs) - if new_node is None: - break - assert new_node != node - node = new_node - return node - - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - if not isinstance(node, ir.FunCall): - return None - - for transformation in self.Flag: - if self.flags & transformation: - assert isinstance(transformation.name, str) - method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node, **kwargs) - if result is not None: - assert ( - result is not node - ) # transformation should have returned None, since nothing changed - itir_type_inference.reinfer(result) - return result - return None + return super().visit(node, **kwargs) def transform_collapse_make_tuple_tuple_get( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple") and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "tuple_get") for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -275,10 +250,9 @@ def transform_collapse_tuple_get_make_tuple( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: if ( - node.fun == ir.SymRef(id="tuple_get") - and isinstance(node.args[1], ir.FunCall) - and node.args[1].fun == ir.SymRef(id="make_tuple") + cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal) + and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` assert type_info.is_integer(node.args[0].type) @@ -291,7 +265,7 @@ def transform_collapse_tuple_get_make_tuple( return None def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` @@ -314,8 +288,8 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ ) return None - def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple"): + def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if cpm.is_call_to(node, "make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[ir.Sym, ir.Expr] = {} @@ -334,7 +308,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op ) return None - def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` @@ -349,7 +323,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt # in local-view for now. Revisit. return None - if not cpm.is_call_to(node, "if_"): + if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -391,7 +365,7 @@ def transform_propagate_to_if_on_tuples_cps( # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this # tuple reordering example here. - if cpm.is_call_to(node, "if_"): + if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): return None # The first argument that is eligible also transforms all remaining args (They will be diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py new file mode 100644 index 0000000000..be34af846b --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -0,0 +1,67 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +import enum +from typing import ClassVar, Optional, Type + +from gt4py import eve +from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as itir_type_inference + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class FixedPointTransformation(eve.NodeTranslator): + """ + Transformation pass that transforms until no transformation is applicable anymore. + """ + + #: Enum of all transformation (names). The transformations need to be defined as methods + #: named `transform_`. + Transformation: ClassVar[Type[enum.Flag]] + + #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. + #: Usually the default value is chosen to be all transformations. + enabled_transformations: enum.Flag + + def visit(self, node, **kwargs): + node = super().visit(node, **kwargs) + return self.fp_transform(node, **kwargs) if isinstance(node, ir.Node) else node + + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: + """ + Transform node until a fixed point is reached, e.g. no transformation is applicable anymore. + """ + while True: + new_node = self.transform(node, **kwargs) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + """ + Transform node once. + + Execute transformations until one is applicable. As soon as a transformation occured + the function will return the transformed node. Note that the transformation itself + may call other transformations on child nodes again. + """ + for transformation in self.Transformation: + if self.enabled_transformations & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node, **kwargs) + if result is not None: + assert ( + result is not node + ) # transformation should have returned None, since nothing changed + itir_type_inference.reinfer(result) + return result + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6906f81e3f..0a79848443 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -76,7 +76,7 @@ def apply_common_transforms( # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply( ir, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program @@ -98,7 +98,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, uids=collapse_tuple_uids, offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program @@ -136,7 +136,7 @@ def apply_common_transforms( ir, ignore_tuple_size=True, uids=collapse_tuple_uids, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program @@ -176,7 +176,7 @@ def apply_fieldview_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) ir = CollapseTuple.apply( ir, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, offset_provider_type=common.offset_provider_to_type(offset_provider), ) # type: ignore[assignment] # type is still `itir.Program` ir = inline_dynamic_shifts.InlineDynamicShifts.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 938b998565..916ae4e578 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -19,7 +19,7 @@ def test_simple_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -37,7 +37,7 @@ def test_nested_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -53,7 +53,7 @@ def test_different_tuples_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -67,7 +67,7 @@ def test_incompatible_order_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -79,7 +79,7 @@ def test_incompatible_size_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -91,7 +91,7 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, ignore_tuple_size=True, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -104,7 +104,7 @@ def test_simple_tuple_get_make_tuple(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, within_stencil=False, ) @@ -117,7 +117,7 @@ def test_propagate_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -135,7 +135,7 @@ def test_letify_make_tuple_elements(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, within_stencil=False, ) @@ -149,7 +149,7 @@ def test_letify_make_tuple_with_trivial_elements(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, within_stencil=False, ) @@ -163,7 +163,7 @@ def test_inline_trivial_make_tuple(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + enabled_transformations=CollapseTuple.Transformation.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, within_stencil=False, ) @@ -182,7 +182,7 @@ def test_propagate_to_if_on_tuples(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, within_stencil=False, ) @@ -199,8 +199,8 @@ def test_propagate_to_if_on_tuples_with_let(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=True, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES - | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES + | CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, within_stencil=False, ) @@ -213,7 +213,7 @@ def test_propagate_nested_lift(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, within_stencil=False, ) @@ -249,7 +249,7 @@ def test_if_make_tuple_reorder_cps(): expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3)) actual = CollapseTuple.apply( testee, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, within_stencil=False, ) @@ -275,7 +275,7 @@ def test_nested_if_make_tuple_reorder_cps(): ) actual = CollapseTuple.apply( testee, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, within_stencil=False, ) @@ -291,7 +291,7 @@ def test_if_make_tuple_reorder_cps_nested(): expected = im.if_(True, im.make_tuple(2, 1, 1), im.make_tuple(4, 3, 3)) actual = CollapseTuple.apply( testee, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, within_stencil=False, ) @@ -306,7 +306,7 @@ def test_if_make_tuple_reorder_cps_external(): expected = im.if_(True, im.make_tuple(external_ref, 2, 1), im.make_tuple(external_ref, 4, 3)) actual = CollapseTuple.apply( testee, - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, allow_undeclared_symbols=True, within_stencil=False, ) From e3788775edc7304fae35c075a398f8274a1277d1 Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 27 Jan 2025 20:55:55 +0100 Subject: [PATCH 116/178] refactor[next]: neg builtin for unary minus (#1819) Add `neg` builtin to GTIR and use it for unary minus, e.g. `-val`, instead of `0-val` as before. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/ffront/fbuiltins.py | 3 +- src/gt4py/next/ffront/foast_to_gtir.py | 12 +++---- src/gt4py/next/iterator/builtins.py | 7 ++++- src/gt4py/next/iterator/embedded.py | 8 +++++ .../codegens/gtfn/codegen.py | 31 ++++++++++--------- .../runners/dace/gtir_python_codegen.py | 1 + .../ffront_tests/test_math_unary_builtins.py | 8 +++++ .../iterator_tests/test_builtins.py | 8 +++-- .../ffront_tests/test_foast_to_gtir.py | 4 +-- 9 files changed, 55 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cef7fc101f..ee14006b22 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,6 +10,7 @@ import functools import inspect import math +import operator from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast @@ -203,7 +204,7 @@ def astype( return core_defs.dtype(type_).scalar_type(value) -_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs} +_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs, "neg": operator.neg} UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()] _UNARY_MATH_FP_BUILTIN_IMPL: Final = { diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..007e195f3e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -241,12 +241,12 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) + if node.op in [dialect_ast_enums.UnaryOperator.USUB]: + return self._lower_and_map("neg", node.operand) + if node.op in [dialect_ast_enums.UnaryOperator.UADD]: + return self.visit(node.operand) + else: + raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: return self._lower_and_map(node.op.value, node.left, node.right) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 959f451e01..8e5f7addca 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -292,6 +292,11 @@ def trunc(*args): raise BackendNotSelectedError() +@builtin_dispatch +def neg(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def isfinite(*args): raise BackendNotSelectedError() @@ -397,7 +402,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() -UNARY_MATH_NUMBER_BUILTINS = {"abs"} +UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { "sin", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 970e88e8c5..16b1fa9d03 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -392,6 +392,13 @@ def not_(a): return not a +@builtins.neg.register(EMBEDDED) +def neg(a): + if isinstance(a, Column): + return np.negative(a) + return np.negative(a) + + @builtins.gamma.register(EMBEDDED) def gamma(a): gamma_ = np.vectorize(math.gamma) @@ -538,6 +545,7 @@ def promote_scalars(val: CompositeOfScalarOrField): "and_": operator.and_, "or_": operator.or_, "xor_": operator.xor, + "neg": operator.neg, } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index c6bf28d8e0..969e203689 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -23,6 +23,7 @@ class GTFNCodegen(codegen.TemplatedGenerator): _builtins_mapping: Final = { "abs": "std::abs", + "neg": "std::negate<>{}", "sin": "std::sin", "cos": "std::cos", "tan": "std::tan", @@ -61,21 +62,21 @@ class GTFNCodegen(codegen.TemplatedGenerator): "int64": "std::int64_t", "uint64": "std::uint64_t", "bool": "bool", - "plus": "std::plus{}", - "minus": "std::minus{}", - "multiplies": "std::multiplies{}", - "divides": "std::divides{}", - "eq": "std::equal_to{}", - "not_eq": "std::not_equal_to{}", - "less": "std::less{}", - "less_equal": "std::less_equal{}", - "greater": "std::greater{}", - "greater_equal": "std::greater_equal{}", - "and_": "std::logical_and{}", - "or_": "std::logical_or{}", - "xor_": "std::bit_xor{}", - "mod": "std::modulus{}", - "not_": "std::logical_not{}", + "plus": "std::plus<>{}", + "minus": "std::minus<>{}", + "multiplies": "std::multiplies<>{}", + "divides": "std::divides<>{}", + "eq": "std::equal_to<>{}", + "not_eq": "std::not_equal_to<>{}", + "less": "std::less<>{}", + "less_equal": "std::less_equal<>{}", + "greater": "std::greater<>{}", + "greater_equal": "std::greater_equal<>{}", + "and_": "std::logical_and<>{}", + "or_": "std::logical_or<>{}", + "xor_": "std::bit_xor<>{}", + "mod": "std::modulus<>{}", + "not_": "std::logical_not<>{}", } Sym = as_fmt("{id}") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index dfbba9c88b..56a67510e7 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -20,6 +20,7 @@ MATH_BUILTINS_MAPPING = { "abs": "abs({})", + "neg": "(- {})", "sin": "math.sin({})", "cos": "math.cos({})", "tan": "math.tan({})", diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 89c341e9a6..1707adada8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -128,6 +128,14 @@ def uneg(inp: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1) +def test_unary_pos(cartesian_case): + @gtx.field_operator + def upos(inp: cases.IField) -> cases.IField: + return +inp + + cases.verify_with_default_data(cartesian_case, upos, ref=lambda inp1: inp1) + + def test_unary_neg_float_conversion(cartesian_case): @gtx.field_operator def uneg_float() -> cases.IFloatField: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 885a272bfe..01637e56e0 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -46,6 +46,8 @@ plus, shift, xor_, + neg, + abs, ) from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn @@ -135,6 +137,8 @@ def fenimpl(size, arg0, arg1, arg2, out): def arithmetic_and_logical_test_data(): return [ # (builtin, inputs, expected) + (abs, [[-1.0, 1.0]], [1.0, 1.0]), + (neg, [[-1.0, 1.0, -1, 1]], [1.0, -1.0, 1, -1]), (plus, [2.0, 3.0], 5.0), (minus, [2.0, 3.0], -1.0), (multiplies, [2.0, 3.0], 6.0), @@ -180,8 +184,8 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): - if builtin == if_: - pytest.skip("If cannot be used unapplied") + if builtin == if_ or builtin == abs: + pytest.skip("If and abs cannot be used unapplied.") inps = field_maker(*array_maker(*inputs)) out = field_maker((np.zeros_like(*array_maker(expected))))[0] diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 59a8dc961b..d2d5404cb5 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -378,7 +378,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("minus")(im.literal("0", "float64"), "inp") + reference = im.op_as_fieldop("neg")("inp") assert lowered.expr == reference @@ -390,7 +390,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("plus")(im.literal("0", "float64"), "inp") + reference = im.ref("inp") assert lowered.expr == reference From 9be2d2d22df9cb45c8cf1b75f78d72dbeda87cf5 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 28 Jan 2025 12:50:33 +0100 Subject: [PATCH 117/178] refactor[next]: new ir.makers for common builtins (#1827) Using `im.call("...")` inside the transformations is cumbersome. This PR adds new helpers to the `ir_makers` for all commonly used builtins used inside of the transformations, namely: `reduce`, `scan`, `list_get`, `maximum`, `minimum`, `cast_`, and `can_deref`. The helpers for `floordiv_` and `mod` were removed since they weren't used and are rather uncommon anyway. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/ffront/foast_to_gtir.py | 14 +- .../ir_utils/common_pattern_matcher.py | 2 +- .../next/iterator/ir_utils/domain_utils.py | 4 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 56 ++++++-- .../iterator/transforms/collapse_list_get.py | 4 +- .../iterator/transforms/inline_fundefs.py | 4 +- .../next/iterator/transforms/inline_lifts.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 52 ++------ .../ffront_tests/test_foast_to_gtir.py | 44 +++--- .../iterator_tests/test_type_inference.py | 126 ++++++++---------- .../transforms_tests/test_constant_folding.py | 6 +- .../transforms_tests/test_cse.py | 8 +- .../transforms_tests/test_domain_inference.py | 4 +- .../transforms_tests/test_inline_lifts.py | 10 +- .../transforms_tests/test_prune_casts.py | 4 +- .../transforms_tests/test_trace_shifts.py | 2 +- .../transforms_tests/test_unroll_reduce.py | 39 ++---- .../gtfn_tests/test_gtfn_module.py | 2 +- .../gtfn_tests/test_itir_to_gtfn_ir.py | 2 +- .../dace_tests/test_gtir_to_sdfg.py | 78 ++++------- 20 files changed, 197 insertions(+), 268 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 007e195f3e..f884ec555d 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -138,7 +138,7 @@ def visit_ScanOperator( definition = itir.Lambda(params=func_definition.params, expr=new_body) - body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args) + body = im.as_fieldop(im.scan(definition, forward, init))(*stencil_args) return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) @@ -360,7 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) + return _map(im.lambda_("val")(im.cast_("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) @@ -409,7 +409,7 @@ def _make_reduction_expr( # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, init_expr)) + val = im.reduce(op, init_expr) return im.op_as_fieldop(val)(it) def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -462,14 +462,14 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: + def _lower_and_map(self, op: itir.Lambda | str, *args: Any, **kwargs: Any) -> itir.FunCall: return _map( op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) ) def _map( - op: itir.Expr | str, + op: itir.Lambda | str, lowered_args: tuple, original_arg_types: tuple[ts.TypeSpec, ...], ) -> itir.FunCall: @@ -487,9 +487,9 @@ def _map( promote_to_list(arg_type)(larg) for arg_type, larg in zip(original_arg_types, lowered_args) ) - op = im.call("map_")(op) + op = im.map_(op) - return im.op_as_fieldop(im.call(op))(*lowered_args) + return im.op_as_fieldop(op)(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 19d0802f4b..c16b9f2b48 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -72,7 +72,7 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: attribute which can be anything. >>> from gt4py.next.iterator.ir_utils import ir_makers as im - >>> node = im.call("plus")(1, 2) + >>> node = im.plus(1, 2) >>> is_call_to(node, "plus") True >>> is_call_to(node, "minus") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index c84e2c0228..27900b6db6 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -162,11 +162,11 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) for dim in domains[0].ranges.keys(): start = functools.reduce( - lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + lambda current_expr, el_expr: im.minimum(current_expr, el_expr), [domain.ranges[dim].start for domain in domains], ) stop = functools.reduce( - lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + lambda current_expr, el_expr: im.maximum(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) # constant fold expression to keep the tree small diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index c5cf2efa5a..24842ad3be 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -169,18 +169,6 @@ def divides_(left, right): return call("divides")(left, right) -def floordiv_(left, right): - """Create a floor division FunCall, shorthand for ``call("floordiv")(left, right)``.""" - # TODO(tehrengruber): Use int(floor(left/right)) as soon as we support integer casting - # and remove the `floordiv` builtin again. - return call("floordiv")(left, right) - - -def mod(left, right): - """Create a modulo FunCall, shorthand for ``call("mod")(left, right)``.""" - return call("mod")(left, right) - - def and_(left, right): """Create an and_ FunCall, shorthand for ``call("and_")(left, right)``.""" return call("and_")(left, right) @@ -302,7 +290,10 @@ def shift(offset, value=None): offset = ensure_offset(offset) args = [offset] if value is not None: - value = ensure_offset(value) + if isinstance(value, int): + value = ensure_offset(value) + elif isinstance(value, str): + value = ref(value) args.append(value) return call(call("shift")(*args)) @@ -469,7 +460,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal def op_as_fieldop( - op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None + op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None ) -> Callable[..., itir.FunCall]: """ Promotes a function `op` to a field_operator. @@ -536,3 +527,40 @@ def index(dim: common.Dimension) -> itir.FunCall: def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) + + +def reduce(op, expr): + """Create a `reduce` call.""" + return call(call("reduce")(op, expr)) + + +def scan(expr, forward, init): + """Create a `scan` call.""" + return call("scan")(expr, forward, init) + + +def list_get(list_idx, list_): + """Create a `list_get` call.""" + return call("list_get")(list_idx, list_) + + +def maximum(expr1, expr2): + """Create a `maximum` call.""" + return call("maximum")(expr1, expr2) + + +def minimum(expr1, expr2): + """Create a `minimum` call.""" + return call("minimum")(expr1, expr2) + + +def cast_(expr, dtype: ts.ScalarType | str): + """Create a `cast_` call.""" + if isinstance(dtype, ts.ScalarType): + dtype = dtype.kind.name.lower() + return call("cast_")(expr, dtype) + + +def can_deref(expr): + """Create a `can_deref` call.""" + return call("can_deref")(expr) diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 4a354879ca..b0a0c1e1dc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -27,8 +27,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: cond, true_val, false_val = node.args[1].args return im.if_( cond, - self.visit(im.call("list_get")(list_idx, true_val)), - self.visit(im.call("list_get")(list_idx, false_val)), + self.visit(im.list_get(list_idx, true_val)), + self.visit(im.list_get(list_idx, false_val)), ) if cpm.is_call_to(node.args[1], "neighbors"): offset_tag = node.args[1].args[0] diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index e4cae978da..03b20d14fe 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -36,12 +36,12 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> fun1 = itir.FunctionDefinition( ... id="fun1", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> fun2 = itir.FunctionDefinition( ... id="fun2", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> program = itir.Program( ... id="testee", diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f27dbbb74c..7724aa86f6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -197,11 +197,11 @@ def visit_FunCall( if len(args) == 0: return im.literal_from_value(True) - res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) + res = im.can_deref(args[0]) for arg in args[1:]: res = ir.FunCall( fun=ir.SymRef(id="and_"), - args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], + args=[res, im.can_deref(arg)], ) return res elif ( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 042a86cd8e..6e993a2ed7 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -14,7 +14,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -85,34 +85,6 @@ def _get_connectivity( return connectivities[0] -def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), - args=[iterator], - location=iterator.location, - ) - - -def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) - - -def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location - ) - - -def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], location=cond.location - ) - - -def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) - - @dataclasses.dataclass(frozen=True) class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are @@ -130,27 +102,25 @@ def _visit_reduce( max_neighbors = connectivity_type.max_neighbors has_skip_values = connectivity_type.has_skip_values - acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) - offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) - step = itir.SymRef(id=self.uids.sequential_id(prefix="_step")) + acc: str = self.uids.sequential_id(prefix="_acc") + offset: str = self.uids.sequential_id(prefix="_i") + step: str = self.uids.sequential_id(prefix="_step") assert isinstance(node.fun, itir.FunCall) fun, init = node.fun.args - elems = [_make_list_get(offset, arg) for arg in node.args] - step_fun: itir.Expr = itir.FunCall(fun=fun, args=[acc, *elems]) + elems = [im.list_get(offset, arg) for arg in node.args] + step_fun: itir.Expr = im.call(fun)(acc, *elems) if has_skip_values: check_arg = next(_get_neighbors_args(node.args)) offset_tag, it = check_arg.args - can_deref = _make_can_deref(_make_shift([offset_tag, offset], it)) - step_fun = _make_if(can_deref, step_fun, acc) - step_fun = itir.Lambda(params=[itir.Sym(id=acc.id), itir.Sym(id=offset.id)], expr=step_fun) + can_deref = im.can_deref(im.shift(offset_tag, offset)(it)) + step_fun = im.if_(can_deref, step_fun, acc) + step_fun = im.lambda_(acc, offset)(step_fun) expr = init for i in range(max_neighbors): - expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) - expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun] - ) + expr = im.call(step)(expr, itir.OffsetLiteral(value=i)) + expr = im.let(step, step_fun)(expr) return expr diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index d2d5404cb5..c0d762efc8 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -91,7 +91,7 @@ def foo(bar: int64, alpha: int64) -> int64: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("multiplies")("alpha", "bar") + reference = im.multiplies_("alpha", "bar") assert lowered.expr == reference @@ -297,7 +297,7 @@ def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a") assert lowered.expr == reference @@ -310,7 +310,7 @@ def foo(a: float64): lowered = FieldOperatorLowering.apply(parsed) lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) - reference = im.call("cast_")("a", "int32") + reference = im.cast_("a", "int32") assert lowered_inlined.expr == reference @@ -341,7 +341,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]): reference = im.make_tuple( im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), - im.call("cast_")(im.tuple_get(1, "a"), "int32"), + im.cast_(im.tuple_get(1, "a"), "int32"), ) assert lowered_inlined.expr == reference @@ -551,7 +551,7 @@ def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]: reference = im.let( ssa.unique_name("tmp", 0), - im.call("plus")( + im.plus( im.literal("1", "int32"), im.literal("1", "int32"), ), @@ -656,7 +656,7 @@ def foo() -> bool: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("greater")( + reference = im.greater( im.literal("3", "int32"), im.literal("4", "int32"), ) @@ -761,11 +761,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -780,11 +778,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "maximum", - im.literal(value=str(np.finfo(np.float64).min), typename="float64"), - ) + im.reduce( + "maximum", + im.literal(value=str(np.finfo(np.float64).min), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -799,11 +795,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "minimum", - im.literal(value=str(np.finfo(np.float64).max), typename="float64"), - ) + im.reduce( + "minimum", + im.literal(value=str(np.finfo(np.float64).max), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -828,11 +822,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64] im.as_fieldop_neighbors("V2E", "e1"), )( im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(mapped) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d4d7c60d69..a39fe3c6d8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -77,10 +77,10 @@ def expression_test_cases(): (im.plus(1, 2), int_type), (im.eq(1, 2), bool_type), (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), - (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), + (im.can_deref(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), + (im.list_get(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( im.call("named_range")( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 @@ -119,7 +119,7 @@ def expression_test_cases(): ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast - (im.call("cast_")(1, "int32"), int_type), + (im.cast_(1, int_type), int_type), # TODO: lift # TODO: scan # map @@ -128,18 +128,16 @@ def expression_test_cases(): int_list_type, ), # reduce - (im.call(im.call("reduce")("plus", 0))(im.ref("l", int_list_type)), int_type), + (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( - im.call( - im.call("reduce")( - im.lambda_("acc", "a", "b")( - im.make_tuple( - im.plus(im.tuple_get(0, "acc"), "a"), - im.plus(im.tuple_get(1, "acc"), "b"), - ) - ), - im.make_tuple(0, 0.0), - ) + im.reduce( + im.lambda_("acc", "a", "b")( + im.make_tuple( + im.plus(im.tuple_get(0, "acc"), "a"), + im.plus(im.tuple_get(1, "acc"), "b"), + ) + ), + im.make_tuple(0, 0.0), )(im.ref("la", int_list_type), im.ref("lb", float64_list_type)), ts.TupleType(types=[int_type, float64_type]), ), @@ -148,42 +146,36 @@ def expression_test_cases(): (im.shift("Ioff", 1)(im.ref("it", it_ijk_type)), it_ijk_type), # as_fieldop ( - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), float_i_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - im.call("unstructured_domain")( - im.call("named_range")( - itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), - 0, - 1, - ), - im.call("named_range")( - itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 - ), + im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), + im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, ), - ) + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), + ), )(im.ref("inp", float_edge_k_field)), float_vertex_k_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), @@ -197,21 +189,17 @@ def expression_test_cases(): ( im.if_( False, - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field), 1.0), - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), ), float_i_field, @@ -276,9 +264,7 @@ def test_cast_first_arg_inference(): # since cast_ is a grammar builtin whose return type is given by its second argument it is # easy to forget inferring the types of the first argument and its children. Simply check # if the first argument has a type inferred correctly here. - testee = im.call("cast_")( - im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" - ) + testee = im.cast_(im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64") result = itir_type_inference.infer( testee, offset_provider_type={}, allow_undeclared_symbols=True ) @@ -299,9 +285,7 @@ def test_cartesian_fencil_definition(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")(im.ref("deref"), cartesian_domain))( - im.ref("inp") - ), + expr=im.as_fieldop(im.ref("deref"), cartesian_domain)(im.ref("inp")), domain=cartesian_domain, target=im.ref("out"), ), @@ -336,10 +320,8 @@ def test_unstructured_fencil_definition(): declarations=[], body=[ itir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain - ) + expr=im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain )(im.ref("inp")), domain=unstructured_domain, target=im.ref("out"), @@ -375,7 +357,7 @@ def test_function_definition(): body=[ itir.SetAt( domain=cartesian_domain, - expr=im.call(im.call("as_fieldop")(im.ref("bar"), cartesian_domain))(im.ref("inp")), + expr=im.as_fieldop(im.ref("bar"), cartesian_domain)(im.ref("inp")), target=im.ref("out"), ), ], @@ -408,11 +390,9 @@ def test_fencil_with_nb_field_input(): body=[ itir.SetAt( domain=unstructured_domain, - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - unstructured_domain, - ) + expr=im.as_fieldop( + im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))), + unstructured_domain, )(im.ref("inp")), target=im.ref("out"), ), @@ -438,9 +418,7 @@ def test_program_tuple_setat_short_target(): declarations=[], body=[ itir.SetAt( - expr=im.call( - im.call("as_fieldop")(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain) - )(), + expr=im.as_fieldop(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain)(), domain=cartesian_domain, target=im.make_tuple("out"), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..cf325c2daa 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -31,7 +31,7 @@ def test_constant_folding_math_op(): def test_constant_folding_if(): - expected = im.call("plus")("a", 2) + expected = im.plus("a", 2) testee = im.if_( im.literal_from_value(True), im.plus(im.ref("a"), im.literal_from_value(2)), @@ -42,7 +42,7 @@ def test_constant_folding_if(): def test_constant_folding_minimum(): - testee = im.call("minimum")("a", "a") + testee = im.minimum("a", "a") expected = im.ref("a") actual = ConstantFolding.apply(testee) assert actual == expected @@ -56,7 +56,7 @@ def test_constant_folding_literal(): def test_constant_folding_literal_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), im.literal_from_value(2)) + testee = im.maximum(im.literal_from_value(1), im.literal_from_value(2)) expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 14860d9bdd..3909c6f26a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -135,7 +135,7 @@ def test_if_can_deref_no_extraction(offset_provider_type): # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) else 1 testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), # use something more involved where a subexpression can still be eliminated im.literal("1", "int32"), @@ -143,7 +143,7 @@ def test_if_can_deref_no_extraction(offset_provider_type): # (λ(_cs_1) → if can_deref(_cs_1) then (λ(_cs_2) → _cs_2 + _cs_2)(·_cs_1) else 1)(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_1", im.shift("I", 1)("it"))( im.if_( - im.call("can_deref")("_cs_1"), + im.can_deref("_cs_1"), im.let("_cs_2", im.deref("_cs_1"))(im.plus("_cs_2", "_cs_2")), im.literal("1", "int32"), ) @@ -159,14 +159,14 @@ def test_if_can_deref_eligible_extraction(offset_provider_type): # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), ) # (λ(_cs_3) → (λ(_cs_1) → if can_deref(_cs_3) then _cs_1 else _cs_1 + _cs_1)(·_cs_3))(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_3", im.shift("I", 1)("it"))( im.let("_cs_1", im.deref("_cs_3"))( - im.if_(im.call("can_deref")("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) + im.if_(im.can_deref("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 779ab738cb..4a2a441510 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1028,10 +1028,10 @@ def test_arithmetic_builtin(offset_provider): def test_scan(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) testee = im.as_fieldop( - im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) )("a") expected = im.as_fieldop( - im.call("scan")(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), domain, )("a") diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index f81ca5a666..957e7ffe63 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -27,15 +27,15 @@ def inline_lift_test_data(): ), ( # can_deref(lift(f)(args...)) -> and(can_deref(arg[0]), and(can_deref(arg[1]), ...)) - im.call("can_deref")(im.lift("f")("arg1", "arg2")), - im.and_(im.call("can_deref")("arg1"), im.call("can_deref")("arg2")), + im.can_deref(im.lift("f")("arg1", "arg2")), + im.and_(im.can_deref("arg1"), im.can_deref("arg2")), ), ( # can_deref(shift(...)(lift(f)(args...)) -> and(can_deref(shift(...)(arg[0])), and(can_deref(shift(...)(arg[1])), ...)) - im.call("can_deref")(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), + im.can_deref(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), im.and_( - im.call("can_deref")(im.shift("I", 1)("arg1")), - im.call("can_deref")(im.shift("I", 1)("arg2")), + im.can_deref(im.shift("I", 1)("arg1")), + im.can_deref(im.shift("I", 1)("arg2")), ), ), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 77d3323fb4..b1a18ddab8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -16,10 +16,10 @@ def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) + testee = im.plus(im.cast_(x_ref, "float64"), im.cast_(y_ref, "float64")) testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) - expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) + expected = im.plus(im.cast_(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 1cf662e221..dd7a8f4d43 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -44,7 +44,7 @@ def test_neighbors(): def test_reduce(): # λ(inp) → reduce(plus, 0.)(·inp) - testee = im.lambda_("inp")(im.call(im.call("reduce")("plus", 0.0))(im.deref("inp"))) + testee = im.lambda_("inp")(im.reduce("plus", 0.0)(im.deref("inp"))) expected = [{()}] actual = TraceShifts.trace_stencil(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 0760247996..2415a42267 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -35,27 +35,25 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))(im.neighbors("Dim", "x")) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))("x", im.neighbors("Dim", "y")) + return im.reduce("foo", 0.0)("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))( - im.neighbors("Dim", "x"), im.neighbors("Dim2", "y") - ) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x"), im.neighbors("Dim2", "y")) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))( + return im.reduce("foo", 0.0)( im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) @@ -63,7 +61,7 @@ def reduction_with_irrelevant_full_shift(): @pytest.fixture def reduction_if(): UIDs.reset_sequence() - return im.call(im.call("reduce")("foo", 0.0))(im.if_(True, im.neighbors("Dim", "x"), "y")) + return im.reduce("foo", 0.0)(im.if_(True, im.neighbors("Dim", "x"), "y")) @pytest.mark.parametrize( @@ -83,35 +81,26 @@ def test_get_partial_offsets(reduction, request): def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): - acc = ir.SymRef(id="_acc_1") - offset = ir.SymRef(id="_i_2") - step = ir.SymRef(id="_step_3") + acc, offset, step = "_acc_1", "_i_2", "_step_3" red_fun, red_init = red.fun.args - elements = [ir.FunCall(fun=ir.SymRef(id="list_get"), args=[offset, arg]) for arg in red.args] + elements = [im.list_get(offset, arg) for arg in red.args] - step_expr = ir.FunCall(fun=red_fun, args=[acc] + elements) + step_expr = im.call(red_fun)(acc, *elements) if has_skip_values: neighbors_offset = red.args[shifted_arg].args[0] neighbors_it = red.args[shifted_arg].args[1] - can_deref = ir.FunCall( - fun=ir.SymRef(id="can_deref"), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="shift"), args=[neighbors_offset, offset]), - args=[neighbors_it], - ) - ], - ) - step_expr = ir.FunCall(fun=ir.SymRef(id="if_"), args=[can_deref, step_expr, acc]) - step_fun = ir.Lambda(params=[ir.Sym(id=acc.id), ir.Sym(id=offset.id)], expr=step_expr) + can_deref = im.can_deref(im.shift(neighbors_offset, offset)(neighbors_it)) + + step_expr = im.if_(can_deref, step_expr, acc) + step_fun = im.lambda_(acc, offset)(step_expr) step_app = red_init for i in range(max_neighbors): - step_app = ir.FunCall(fun=step, args=[step_app, ir.OffsetLiteral(value=i)]) + step_app = im.call(step)(step_app, ir.OffsetLiteral(value=i)) - return ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id=step.id)], expr=step_app), args=[step_fun]) + return im.let(step, step_fun)(step_app) def test_basic(basic_reduction, has_skip_values): diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 53e463c6c7..e7053d3317 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -60,7 +60,7 @@ def program_example(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")(itir.SymRef(id="stencil"), domain))( + expr=im.as_fieldop(itir.SymRef(id="stencil"), domain)( itir.SymRef(id="buf"), itir.SymRef(id="sc") ), domain=domain, diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 97591122e5..50e8fa43f0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -47,7 +47,7 @@ def test_get_domains(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_fieldop")("deref"))(), + expr=im.as_fieldop("deref")(), domain=domain, target=itir.SymRef(id="bar"), ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index faf611878d..bfde179e33 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1183,9 +1183,7 @@ def test_gtir_neighbors_as_input(): gtir.SetAt( expr=im.as_fieldop( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), vertex_domain, )( @@ -1283,25 +1281,15 @@ def test_gtir_neighbors_as_output(): def test_gtir_reduce(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = im.as_fieldop( + im.lambda_("it")(im.reduce("plus", im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] @@ -1349,25 +1337,15 @@ def test_gtir_reduce(): def test_gtir_reduce_with_skip_values(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = im.as_fieldop( + im.lambda_("it")(im.reduce("plus", im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] @@ -1450,15 +1428,11 @@ def test_gtir_reduce_dot_product(): declarations=[], body=[ gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + expr=im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) + ), + vertex_domain, )( im.op_as_fieldop(im.map_("plus"), vertex_domain)( im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( @@ -1508,9 +1482,7 @@ def test_gtir_reduce_with_cond_neighbors(): gtir.SetAt( expr=im.as_fieldop( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), vertex_domain, )( @@ -1958,8 +1930,8 @@ def test_gtir_if_scalars(): "f", im.if_( "pred", - im.call("cast_")("y_0", "float64"), - im.call("cast_")("y_1", "float64"), + im.cast_("y_0", "float64"), + im.cast_("y_1", "float64"), ), ) ) From a21b1bf2c0fa7a2c1a7e31d6202cba87f215cc26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 29 Jan 2025 07:47:04 +0100 Subject: [PATCH 118/178] docs[next]: Updated the ADR0018 (#1798) This PR clarifies some aspects in the ADR0018, that governs valid structures of the SDFG that are generated by the lowering. --------- Co-authored-by: Rico Haeuselmann --- ...Canonical_SDFG_in_GT4Py_Transformations.md | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index 18b9c1f878..69e09c7fae 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -7,6 +7,7 @@ tags: [backend, dace, optimization] - **Status**: valid - **Authors**: Philip Müller (@philip-paul-mueller) - **Created**: 2024-08-27 +- **Updated**: 2025-01-15 In the context of the implementation of the new DaCe fieldview we decided about a particular form of the SDFG. Their main intent is to reduce the complexity of the GT4Py specific transformations. @@ -22,6 +23,12 @@ In the pipeline we distinguish between: The current (GT4Py) pipeline mainly focus on intrastate optimization and relays on DaCe, especially its simplify pass, for interstate optimizations. +## Changelog + +#### 2025-01-15: + +- Made the rules clearer. Specifically, made a restriction on global memory more explicit. + ## Decision The canonical form is defined by several rules that affect different aspects of an SDFG and what a transformation can assume. @@ -38,20 +45,24 @@ The following rules especially affects transformations and how they operate: - [Note 2]: It is allowed for an _intrastate_ transformation to act in a way that allows state fusion by later intrastate transformations. - [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2. -2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py. +2. It is invalid to call DaCe's simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py. + - [Rationale]: It was observed that some sub-passes in _simplify()_ have a negative impact and that additional passes might be needed in the future. By using a single function later modifications to _simplify()_ are easy. - [Note]: One issue is that the remove redundant array transformation is not able to handle all cases. #### Global Memory -The only restriction we impose on global memory is: +Global memory has to adhere to the same rules as transient memory. +However, the following rule takes precedence, i.e. if this rule is fulfilled then rules 6 to 10 may be violated. + +3. The same global memory is allowed to be used as input and output at the same time, either in the SDFG or in a state, if and only if the output depends _elementwise_ on the input. -3. The same global memory is allowed to be used as input and output at the same time, if and only if the output depends _elementwise_ on the input. - [Rationale 1]: This allows the removal of double buffering, that DaCe may not remove. See also rule 2. - [Rationale 2]: This formulation allows writing expressions such as `a += 1`, with only memory for `a`. Phrased more technically, using global memory for input and output is allowed if and only if the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent. - - [Note]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. + - [Note 1]: This rule also forbids expressions such as `A[0:10] = A[1:11]`, where `A` refers to a global memory. + - [Note 2]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. #### State Machine @@ -63,6 +74,7 @@ For the SDFG state machine we assume that: - [Note]: Running _simplify()_ might actually result in the violation of this rule, see note of rule 9. 5. The state graph does not contain any cycles, i.e. the implementation of a for/while loop using states is not allowed, the new loop construct or serial maps must be used in that case. + - [Rationale]: This is a simplification that makes it much simpler to define what "later in the computation" means, as we will never have a cycle. - [Note]: Currently the code generator does not support the `LoopRegion` construct and it is transformed to a state machine. @@ -93,7 +105,7 @@ It is important to note that these rules only have to be met after _simplify()_ 8. No two access nodes in a state can refer to the same array. - [Rationale]: Together with rule 5 this guarantees SSA style. - - [Note]: An SDFG can still be constructed using different access node for the same underlying data; _simplify()_ will combine them. + - [Note]: An SDFG can still be constructed using different access node for the same underlying data in the same state; _simplify()_ will combine them. 9. Every access node that reads from an array (having an outgoing edge) that was not written to in the same state must be a source node. @@ -103,6 +115,7 @@ It is important to note that these rules only have to be met after _simplify()_ Excess interstate transients, that will be kept alive that way, will be removed by later calls to _simplify()_. 10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should either be `dace.dtypes.StorageType.Default` or _preferably_ `dace.dtypes.StorageType.Register`. + - [Rationale 1]: This makes optimizations operating inside maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. - [Rationale 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. @@ -120,6 +133,7 @@ For maps we assume the following: - [Rationale]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12. 12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range. + - [Rationale 1]: Because of rule 11, we will only fuse maps that actually makes sense to fuse. - [Rationale 2]: This allows fusing maps without renaming the map variables. - [Note]: This rule might be dropped in the future. From d67bd7e1acb71b32f9fdd13ad55ff1734c2a131d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:49:33 +0100 Subject: [PATCH 119/178] fix[dace]: Fixed a bug in `gt_make_transients_persistent()` (#1831) Make a transient persistent essentially makes it a global, it is thus shared among different iterations of a Map. Consider the following: ```c++ #pragma omp parallel for for(int k = 0; k != N; ++k) { double b = foo(k); bar(b); }; ``` In the above code each iteration has its own local copy of `b` that is not shared with other threads. If `b` would be declared persistent, the code essentially becomes ```c++ double* b = new double; #pragma omp parallel for for(int k = 0; k != N; ++k) { *b = foo(k); bar(*b); }; ``` i.e. now `b` is shared among the different threads and we have a data race. To avoid this situation we have to ensure that there is no data race and we do this by requiring that a data descriptor that should be turned into a persistent one, can not have an AccessNode inside a scope other than the top scope. Note that this restriction is stronger than necessary and could be relaxed, but it might be very difficult to figuring out if in a particular case it is possible or not. --- .../runners/dace/transformations/utils.py | 10 +++ .../test_make_transients_persistent.py | 74 +++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 87308061e7..3cc2dadd89 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -50,6 +50,7 @@ def gt_make_transients_persistent( not_modify_lifetime: set[str] = set() for state in nsdfg.states(): + scope_dict = state.scope_dict() for dnode in state.data_nodes(): if dnode.data in not_modify_lifetime: continue @@ -70,6 +71,15 @@ def gt_make_transients_persistent( not_modify_lifetime.add(dnode.data) continue + # If the data is referenced inside a scope, such as a map, it might be possible + # that it is only used inside that scope. If we would make it persistent, then + # it would essentially be allocated outside and be shared among the different + # map iterations. So we can not make it persistent. + # The downside is, that we might have to perform dynamic allocation. + if scope_dict[dnode] is not None: + not_modify_lifetime.add(dnode.data) + continue + try: # The symbols describing the total size must be a subset of the # free symbols of the SDFG (symbols passed as argument). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py new file mode 100644 index 0000000000..d8cf8e33f8 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py @@ -0,0 +1,74 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_transients_persistent_inner_access_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("transients_persistent_inner_access_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "abc": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["b"].transient = True + + me: dace_nodes.MapEntry + mx: dace_nodes.MapExit + me, mx = state.add_map("comp", ndrange={"__i0": "0:10"}) + a, b, c = (state.add_access(name) for name in "abc") + tsklt: dace_nodes.Tasklet = state.add_tasklet( + "tsklt", + inputs={"__in"}, + code="__out = __in + 1.0", + outputs={"__out"}, + ) + + me.add_in_connector("IN_A") + state.add_edge(a, None, me, "IN_A", dace.Memlet("a[0:10]")) + + me.add_out_connector("OUT_A") + state.add_edge(me, "OUT_A", b, None, dace.Memlet("a[__i0] -> [__i0]")) + + state.add_edge(b, None, tsklt, "__in", dace.Memlet("b[__i0]")) + + mx.add_in_connector("IN_C") + state.add_edge(tsklt, "__out", mx, "IN_C", dace.Memlet("c[__i0]")) + + mx.add_out_connector("OUT_C") + state.add_edge(mx, "OUT_C", c, None, dace.Memlet("c[0:10]")) + sdfg.validate() + return sdfg, state + + +def test_make_transients_persistent_inner_access(): + sdfg, state = _make_transients_persistent_inner_access_sdfg() + assert sdfg.arrays["b"].lifetime is dace.dtypes.AllocationLifetime.Scope + + # Because `b`, the only transient, is used inside a map scope, it is not selected, + # although in this situation it would be possible. + change_report: dict[int, set[str]] = gtx_transformations.gt_make_transients_persistent( + sdfg, device=dace.DeviceType.CPU + ) + assert len(change_report) == 1 + assert change_report[sdfg.cfg_id] == set() From ac47ca691aa32deefe925b8d82246b3a56b38cee Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 29 Jan 2025 14:38:37 +0100 Subject: [PATCH 120/178] fix[next][dace]: handle if-expression as non-exclusive in field view (#1824) The implementation of scan field operator introduced a bug in the lowering of if-statements to SDFG. The scan requires extended lowering support for iterator view, therefore #1790 introduced the lowering of local-if with exclusive branch execution. However, attention was not paid to enable the exclusive behavior of if-expressions only in iterator view, that is only in the scope of scan field operators. Regular field operators should instead lower if-expressions to a tasklet, because the field view behavior is that of a local select. --- .../runners/dace/gtir_dataflow.py | 91 ++++++++++++++++++- .../runners/dace/gtir_scan_translator.py | 7 +- 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index e00e363ac4..a34828afcb 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -332,6 +332,7 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder + scan_carry_symbol: Optional[gtir.Sym] input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) symbol_map: dict[ str, @@ -693,7 +694,7 @@ def _visit_if_branch( # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) input_edges, output_tree = translate_lambda_to_dataflow( - if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, args=lambda_args + if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args ) for data_node in if_branch_state.data_nodes(): @@ -736,7 +737,12 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A """ Lowers an if-expression with exclusive branch execution into a nested SDFG, in which each branch is lowered into a dataflow in a separate state and - the if-condition is represented as the inter-state edge condtion. + the if-condition is represented as the inter-state edge condition. + + Exclusive branch execution for local if expressions is meant to be used + in iterator view. Iterator view is required ONLY inside scan field operators. + For regular field operators, the fieldview behavior of if-expressions + corresponds to a local select, therefore it should be lowered to a tasklet. """ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: @@ -1633,13 +1639,87 @@ def _visit_tuple_get( tuple_fields = self.visit(node.args[1]) return tuple_fields[index] + def requires_exclusive_if(self, node: gtir.FunCall) -> bool: + """ + The meaning of `if_` builtin function is unclear in GTIR. + In some context, it corresponds to a ternary operator where, depending on + the condition result, only one branch or the other should be executed, + because one of them is invalid. The typical case is the use of `if_` to + decide whether it is possible or not to access a shifted iterator, for + example when the condition expression calls `can_deref`. + The ternary operator is also used in iterator view, where the field arguments + are not necessarily both defined on the entire output domain (this behavior + should not appear in field view, because there the user code should use + `concat_where` instead of `where` for such cases). It is difficult to catch + such behavior, because it would require to know the exact domain of all + fields, which is not known at compile time. However, the iterator view + behavior should only appear inside scan field operators. + A different usage of `if_` expressions is selecting one argument value or + the other, where both arguments are defined on the output domain, therefore + always valid. + In order to simplify the SDFG and facilitate the optimization stage, we + try to avoid the ternary operator form when not needed. The reason is that + exclusive branch execution is represented in the SDFG as a conditional + state transition, which prevents fusion. + """ + assert cpm.is_call_to(node, "if_") + assert len(node.args) == 3 + + condition_vars = ( + eve.walk_values(node.args[0]) + .if_isinstance(gtir.SymRef) + .map(lambda node: str(node.id)) + .filter(lambda x: x in self.symbol_map) + .to_set() + ) + + # first, check if any argument contains shift expressions that depend on the condition variables + for arg in node.args[1:3]: + shift_nodes = ( + eve.walk_values(arg).filter(lambda node: cpm.is_applied_shift(node)).to_set() + ) + for shift_node in shift_nodes: + shift_vars = ( + eve.walk_values(shift_node) + .if_isinstance(gtir.SymRef) + .map(lambda node: str(node.id)) + .filter(lambda x: x in self.symbol_map) + .to_set() + ) + # require exclusive branch execution if any shift expression one of + # the if branches accesses a variable used in the condition expression + depend_vars = condition_vars.intersection(shift_vars) + if len(depend_vars) != 0: + return True + + # secondly, check whether the `if_` branches access different sets of fields + # and this happens inside a scan field operator + if self.scan_carry_symbol is not None: + # the `if_` node is inside a scan stencil expression + scan_carry_var = str(self.scan_carry_symbol.id) + if scan_carry_var in condition_vars: + br1_vars, br2_vars = ( + eve.walk_values(arg) + .if_isinstance(gtir.SymRef) + .map(lambda node: str(node.id)) + .filter(lambda x: isinstance(self.symbol_map.get(x, None), MemletExpr)) + .to_set() + for arg in node.args[1:3] + ) + if br1_vars != br2_vars: + # the two branches of the `if_` expression access different sets of fields, + # depending on the scan carry value + return True + + return False + def visit_FunCall( self, node: gtir.FunCall ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) - elif cpm.is_call_to(node, "if_"): + elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node): return self._visit_if(node) elif cpm.is_call_to(node, "neighbors"): @@ -1776,6 +1856,7 @@ def translate_lambda_to_dataflow( | ValueExpr | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] ], + scan_carry_symbol: Optional[gtir.Sym] = None, ) -> tuple[ list[DataflowInputEdge], tuple[DataflowOutputEdge | tuple[Any, ...], ...], @@ -1794,13 +1875,15 @@ def translate_lambda_to_dataflow( sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. + scan_carry_symbol: When set, the lowering of `if_` expression will consider + using the ternary operator form with exclusive branch execution. Returns: A tuple of two elements: - List of connections for data inputs to the dataflow. - Tree representation of output data connections. """ - taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder, scan_carry_symbol) lambda_output = taskgen.visit_let(node, args) if isinstance(lambda_output, DataflowOutputEdge): diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index d3d8e101a7..ec88cd8f84 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -417,7 +417,12 @@ def init_scan_carry(sym: gtir.Sym) -> None: # stil inside the 'compute' state, generate the dataflow representing the stencil # to be applied on the horizontal domain lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( - nsdfg, compute_state, lambda_translator, lambda_node, args=stencil_args + nsdfg, + compute_state, + lambda_translator, + lambda_node, + stencil_args, + scan_carry_symbol=scan_carry_symbol, ) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region From f732f08020e0786fe955231eb8794c83c5778d00 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 29 Jan 2025 15:51:52 +0100 Subject: [PATCH 121/178] refactor[next][daxce]: cleanup get_sdfg_args() parameter list (#1834) Remove unnecessary `kwargs` from parameter list of `get_sdfg_args()`. --- .../runners/dace/sdfg_callable.py | 14 +++++++++++--- .../runners/dace/workflow/decoration.py | 1 - 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index ecd23619e5..7f221a5a41 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -129,15 +129,23 @@ def get_sdfg_args( *args: Any, check_args: bool = False, on_gpu: bool = False, - **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. - This function can handle the same arguments that are passed to dace runner. + This function can handle the arguments that are passed to the dace runner + and that end up in the decoration stage of the dace backend workflow. Args: sdfg: The SDFG for which we want to get the arguments. - offset_provider: Offset provider. + offset_provider: The offset provider. + args: The list of arguments passed to the dace runner. + check_args: If True, return only the arguments that are expected + according to the SDFG signature. + on_gpu: If True, this method ensures that the arrays for the + connectivity tables are allocated in GPU memory. + + Returns: + A dictionary of keyword arguments to be passed in the SDFG call. """ dace_args = _get_args(sdfg, args) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 2ee99f5fa4..9648ac9e04 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -86,7 +86,6 @@ def decorated_program( *flat_args, check_args=False, on_gpu=on_gpu, - use_field_canonical_representation=use_field_canonical_representation, ) with dace.config.temporary_config(): From c228f224e4a416d8b224d3bcfad0eda92755ff88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 29 Jan 2025 15:53:00 +0100 Subject: [PATCH 122/178] build: update infrastructure to use the uv tool (#1813) Modernize development infrastructure to use newer python tools. The `uv` (https://docs.astral.sh/uv/) tool is used as project management tool, `nox` (https://nox.thea.codes/) replaces `tox` as the project test runner, custom scripts and `cog` snippets are replaced by a central `dev-tasks.py` tool with multiple commands. Additionally, documentation and CI workflows are updated accordingly and some tests which were not properly executed properly are fixed. Other related tasks to enhance the development infrastructure which could be done after this PR is merged are being tracked in issue #1829 --- .github/workflows/code-quality.yml | 18 +- .github/workflows/daily-ci.yml | 126 +- .github/workflows/test-cartesian-fallback.yml | 6 +- .github/workflows/test-cartesian.yml | 43 +- .github/workflows/test-eve-fallback.yml | 10 +- .github/workflows/test-eve.yml | 51 +- .github/workflows/test-examples.yml | 44 + .github/workflows/test-next-fallback.yml | 5 +- .github/workflows/test-next.yml | 60 +- .github/workflows/test-notebooks.yml | 42 - .github/workflows/test-storage-fallback.yml | 11 +- .github/workflows/test-storage.yml | 52 +- .pre-commit-config.yaml | 95 +- .python-version | 1 + CONTRIBUTING.md | 26 +- LICENSE.txt => LICENSE | 0 README.md | 86 +- ci/base.Dockerfile | 2 +- ci/cscs-ci.yml | 19 +- constraints.txt | 178 - dev-tasks.py | 97 + docs/development/tools/ci-infrastructure.md | 2 +- docs/development/tools/requirements.md | 27 - min-extra-requirements-test.txt | 110 - min-requirements-test.txt | 104 - noxfile.py | 251 ++ pyproject.toml | 122 +- requirements-dev.in | 36 - requirements-dev.txt | 178 - tach.toml | 4 +- .../multi_feature_tests/test_dace_parsing.py | 2 +- tests/conftest.py | 50 + tests/next_tests/definitions.py | 11 +- .../feature_tests/dace/__init__.py | 4 + .../feature_tests/dace/test_orchestration.py | 7 +- .../feature_tests/dace/test_program.py | 13 +- .../iterator_tests/test_extractors.py | 56 +- .../ffront_tests/test_ffront_fvm_nabla.py | 21 +- .../runners_tests/dace_tests/__init__.py | 3 +- .../runners_tests/dace_tests/test_dace.py | 2 - .../dace_tests/test_gtir_to_sdfg.py | 2 - .../transformation_tests/__init__.py | 3 +- .../transformation_tests/conftest.py | 2 +- .../test_distributed_buffer_relocator.py | 9 +- .../transformation_tests/test_gpu_utils.py | 2 +- .../test_loop_blocking.py | 2 +- .../transformation_tests/test_map_fusion.py | 2 +- .../transformation_tests/test_map_order.py | 8 +- .../test_serial_map_promoter.py | 2 +- tox.ini | 190 - uv.lock | 3345 +++++++++++++++++ 51 files changed, 4181 insertions(+), 1361 deletions(-) create mode 100644 .github/workflows/test-examples.yml delete mode 100644 .github/workflows/test-notebooks.yml create mode 100644 .python-version rename LICENSE.txt => LICENSE (100%) delete mode 100644 constraints.txt create mode 100755 dev-tasks.py delete mode 100644 docs/development/tools/requirements.md delete mode 100644 min-extra-requirements-test.txt delete mode 100644 min-requirements-test.txt create mode 100644 noxfile.py delete mode 100644 requirements-dev.in delete mode 100644 requirements-dev.txt delete mode 100644 tox.ini create mode 100644 uv.lock diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index d54fea9269..10bb537e3e 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -13,13 +13,17 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.10" - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - uses: pre-commit/action@v3.0.1 + python-version-file: ".python-version" + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: "Run pre-commit" + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 28512a18ac..a2a52ce1ff 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -5,7 +5,8 @@ on: - cron: '0 4 * * *' workflow_dispatch: - ## COMMENTED OUT: only for testing CI action changes + ## COMMENTED OUT: only for testing CI action changes. + ## It only works for PRs to `main` branch from branches in the upstream gt4py repo. # pull_request: # branches: # - main @@ -15,106 +16,87 @@ jobs: daily-ci: strategy: matrix: + # dependencies-strategy -> The strategy that `uv lock` should use to select + # between the different compatible versions for a given package requirement + # [arg: --resolution, env: UV_RESOLUTION=] + dependencies-strategy: ["lowest-direct", "highest"] + gt4py-module: ["cartesian", "eve", "next", "storage"] + os: ["ubuntu-latest"] #, "macos-latest"] python-version: ["3.10", "3.11"] - tox-module-factor: ["cartesian", "eve", "next", "storage"] - os: ["ubuntu-latest"] - requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] fail-fast: false runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash - run: | - brew install boost + run: brew install boost + - name: Install C++ libraries if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash - run: | - sudo apt install libboost-dev - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install tox - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel tox - python -m pip list - - name: Update requirements - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e requirements-py${pyversion_no_dot} - # TODO(egparedes): add notification for dependencies updates - # - name: Check for updated requirements - # id: update-requirements - # continue-on-error: true - # if: ${{ matrix.python-version == '3.8' && matrix.tox-module-factor == 'cartesian' }} - # shell: bash - # run: | - # if diff -q constraints.txt CURRENT-constraints.txt; then - # echo "REQS_DIFF=''" >> $GITHUB_OUTPUT - # else - # diff --changed-group-format='%<' --unchanged-group-format='' constraints.txt CURRENT-constraints.txt | tr '\n' ' ' > constraints.txt.diff - # echo "REQS_DIFF='$(cat constraints.txt.diff)'" >> $GITHUB_OUTPUT - # fi - # echo "REQS_DIFF_TEST="FOOOOOOOO" >> $GITHUB_OUTPUT - # - name: Notify updated requirements (if any) - # if: ${{ steps.update-requirements.outputs.REQS_DIFF }} - # env: - # SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} - # uses: slackapi/slack-github-action@v1.23.0 - # with: - # channel-id: ${{ vars.SLACK_BOT_CHANNEL }} - # payload: | - # { - # "text": "TEXT", - # "blocks": [ - # { - # "type": "section", - # "text": { - # "type": "plain_text", - # "text": "@channel: AA/${{ steps.update-requirements.outputs.REQS_DIFF }}/BB/ ${{ steps.update-requirements.outputs.REQS_DIFF_TEST }} /CC" - # } - # }, - # { - # "type": "section", - # "text": { - # "type": "mrkdwn", - # "text": "@channel: AA/${{ steps.update-requirements.outputs.REQS_DIFF }}/BB/ ${{ steps.update-requirements.outputs.REQS_DIFF_TEST }} /CC" - # } - # } - # ] - # } - - name: Run tests + + - name: Run CPU tests for '${{ matrix.gt4py-module }}' with '${{ matrix.dependencies-strategy }}' resolution strategy env: NUM_PROCESSES: auto - ENV_REQUIREMENTS_FILE: ${{ matrix.requirements-file }} - run: | - tox run --skip-missing-interpreters -m test-${{ matrix.tox-module-factor }}-cpu + UV_RESOLUTION: ${{ matrix.dependencies-strategy }} + run: uv run nox -s 'test_${{ matrix.gt4py-module }}-${{ matrix.python-version }}' -t 'cpu' + - name: Notify slack if: ${{ failure() }} env: SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} uses: slackapi/slack-github-action@v1.23.0 with: - channel-id: ${{ vars.SLACK_BOT_CHANNEL }} + channel-id: ${{ vars.SLACK_BOT_CHANNEL }} # Use SLACK_BOT_CHANNEL_TEST for testing + payload: | + { + "text": "Failed tests for ${{ github.workflow }} (dependencies-strategy=${{ matrix.dependencies-strategy }}, python=${{ matrix.python-version }}, component=${{ matrix.gt4py-module }}) [https://github.com/GridTools/gt4py/actions/runs/${{ github.run_id }}].", + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Failed tests: " + } + } + ] + } + + weekly-reminder: + runs-on: ubuntu-latest + steps: + - id: get_day_of_the_week + name: Get day of the week + run: echo "day_of_week=$(date +'%u')" >> $GITHUB_OUTPUT + + - name: Weekly notification + if: ${{ env.DAY_OF_WEEK == 1 }} + env: + DAY_OF_WEEK: ${{ steps.get_day_of_the_week.outputs.day_of_week }} + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + uses: slackapi/slack-github-action@v1.23.0 + with: + channel-id: ${{ vars.SLACK_BOT_CHANNEL }} # Use SLACK_BOT_CHANNEL_TEST for testing payload: | { - "text": "${{ github.workflow }}: `test-${{ matrix.tox-module-factor }}-cpu (python${{ matrix.python-version }})`>: *Failed tests!*", + "text": "Weekly reminder to check the latest runs of the GT4Py Daily CI workflow at the GitHub Actions dashboard [https://github.com/GridTools/gt4py/actions/workflows/daily-ci.yml].", "blocks": [ { "type": "section", "text": { "type": "mrkdwn", - "text": ": *Failed tests!*" + "text": "Weekly reminder to check the latest runs of the workflow at the GitHub Actions dashboard." } } ] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index 8061ca56b9..a846af2e7b 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -13,11 +13,11 @@ on: jobs: test-cartesian: - runs-on: ubuntu-latest strategy: matrix: + codegen-factor: [internal, dace] + os: ["ubuntu-latest"] python-version: ["3.10", "3.11"] - tox-factor: [internal, dace] - + runs-on: ${{ matrix.os }} steps: - run: 'echo "No build required"' diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 4b5a790f4d..ea6b7940a3 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -7,7 +7,7 @@ on: pull_request: branches: - main - paths-ignore: # Skip if only gt4py.next and irrelevant doc files have been updated + paths-ignore: # Skip when only gt4py.next or doc files have been updated - "src/gt4py/next/**" - "tests/next_tests/**" - "examples/**" @@ -20,35 +20,36 @@ concurrency: jobs: test-cartesian: - runs-on: ubuntu-latest strategy: matrix: + codegen-factor: [internal, dace] + os: ["ubuntu-latest"] python-version: ["3.10", "3.11"] - tox-factor: [internal, dace] + fail-fast: false + + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + - name: Install C++ libraries + if: ${{ matrix.os == 'macos-latest' }} shell: bash - run: | - sudo apt install libboost-dev - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + run: brew install boost + + - name: Install C++ libraries + if: ${{ matrix.os == 'ubuntu-latest' }} + shell: bash + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Test with tox + + - name: Run CPU 'cartesian' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e cartesian-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu + run: uv run nox -s 'test_cartesian-${{ matrix.python-version }}(${{ matrix.codegen-factor }}, cpu)' diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 78f6136888..f3dbb58acf 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -1,15 +1,17 @@ name: "Fallback: Test Eve" on: + push: + branches: + - main pull_request: branches: - main paths-ignore: # Inverse of corresponding workflow - "src/gt4py/eve/**" - "tests/eve_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -17,8 +19,8 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.10", "3.11"] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 9d48d50c03..aad3971ad0 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -10,9 +10,8 @@ on: paths: # Run when gt4py.eve files (or package settings) are changed - "src/gt4py/eve/**" - "tests/eve_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -20,51 +19,23 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.10", "3.11"] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run 'eve' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e eve-py${pyversion_no_dot} - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v4 - # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v4 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }} - # path: info.txt + run: uv run nox -s test_eve-${{ matrix.python-version }} diff --git a/.github/workflows/test-examples.yml b/.github/workflows/test-examples.yml new file mode 100644 index 0000000000..836af45dd1 --- /dev/null +++ b/.github/workflows/test-examples.yml @@ -0,0 +1,44 @@ +name: "Test examples in documentation" + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test-notebooks: + strategy: + matrix: + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] + fail-fast: false + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + + - name: Install C++ libraries + if: ${{ matrix.os == 'macos-latest' }} + shell: bash + run: brew install boost + + - name: Install C++ libraries + if: ${{ matrix.os == 'ubuntu-latest' }} + shell: bash + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + python-version: ${{ matrix.python-version }} + + - name: Run 'docs' nox session + env: + NUM_PROCESSES: auto + shell: bash + run: uv run nox -s 'test_examples-${{ matrix.python-version }}' diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index 16a0cf0df3..ef8be3df5f 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -15,9 +15,10 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: ["nomesh", "atlas"] + codegen-factor: [internal, dace] + mesh-factor: [nomesh, atlas] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 1928370202..068377c6c7 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -7,7 +7,7 @@ on: pull_request: branches: - main - paths-ignore: # Skip if only gt4py.cartesian and irrelevant doc files have been updated + paths-ignore: # Skip when only gt4py.cartesian or doc files have been updated - "src/gt4py/cartesian/**" - "tests/cartesian_tests/**" - "examples/**" @@ -18,63 +18,35 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: ["nomesh", "atlas"] + codegen-factor: [internal, dace] + mesh-factor: [nomesh, atlas] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash - run: | - brew install boost + run: brew install boost + - name: Install C++ libraries if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash - run: | - sudo apt install libboost-dev - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - shell: bash - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run CPU 'next' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v4 - # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v4 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu - # path: info.txt + run: uv run nox -s 'test_next-${{ matrix.python-version }}(${{ matrix.codegen-factor }}, cpu, ${{ matrix.mesh-factor }})' diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml deleted file mode 100644 index ae45cb154d..0000000000 --- a/.github/workflows/test-notebooks.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: "Test Jupyter Notebooks" - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - test-notebooks: - strategy: - matrix: - python-version: ["3.10", "3.11"] - os: ["ubuntu-latest"] - fail-fast: false - - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests - env: - NUM_PROCESSES: auto - shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e notebooks-py${pyversion_no_dot} diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 46a4442520..c913529a1c 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -1,6 +1,9 @@ name: "Fallback: Test Storage (CPU)" on: + push: + branches: + - main pull_request: branches: - main @@ -8,9 +11,8 @@ on: - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages - "tests/storage_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -18,9 +20,8 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: [internal, dace] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 3748ac193e..b2bb09dfcc 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -11,9 +11,8 @@ on: - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages - "tests/storage_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -21,52 +20,23 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: [internal, dace] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run CPU 'storage' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e storage-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v4 - # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v4 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }} - # path: info.txt + run: uv run nox -s 'test_storage-${{ matrix.python-version }}(cpu)' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 051781ea49..4222224cc4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,21 +1,23 @@ -# ----------------------------------------------------------------------- -# This file contains 'cog' snippets (https://nedbatchelder.com/code/cog/) -# to keep version numbers in sync with 'constraints.txt' -# ----------------------------------------------------------------------- - default_language_version: python: python3.10 +minimum_pre_commit_version: 3.8.0 repos: # - repo: meta # hooks: # - id: check-hooks-apply # - id: check-useless-excludes + +- repo: https://github.com/astral-sh/uv-pre-commit + # uv version. + rev: 0.5.10 + hooks: + - id: uv-lock + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks rev: v2.6.0 hooks: - id: pretty-format-ini args: [--autofix] - exclude: tox.ini - id: pretty-format-toml args: [--autofix] exclude: tach.toml @@ -45,86 +47,25 @@ repos: - id: check-yaml - repo: https://github.com/astral-sh/ruff-pre-commit - ##[[[cog - ## import re - ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] - ## print(f"rev: v{version}") - ##]]] rev: v0.8.6 - ##[[[end]]] hooks: - # Run the linter. - # TODO: include tests here - id: ruff - files: ^src/ + files: ^src/ # TODO(egparedes): also add the `tests` folder here args: [--fix] - # Run the formatter. - id: ruff-format - repo: https://github.com/gauge-sh/tach-pre-commit - rev: v0.10.7 + rev: v0.23.0 hooks: - id: tach -- repo: https://github.com/pre-commit/mirrors-mypy - ##[[[cog - ## import re - ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] - ## print(f"#========= FROM constraints.txt: v{version} =========") - ##]]] - #========= FROM constraints.txt: v1.14.1 ========= - ##[[[end]]] - rev: v1.14.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) +- repo: local hooks: - id: mypy - additional_dependencies: # versions from constraints.txt - ##[[[cog - ## import re, sys - ## if sys.version_info >= (3, 11): - ## import tomllib - ## else: - ## import tomli as tomllib - ## constraints = open("constraints.txt").read() - ## project = tomllib.loads(open("pyproject.toml").read()) - ## packages = [re.match('^([\w-][\w\d-]*)', r)[1] for r in project["project"]["dependencies"] if r.strip()] - ## for pkg in packages: - ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) - ##]]] - - attrs==24.3.0 - - black==24.10.0 - - boltons==24.1.0 - - cached-property==2.0.1 - - click==8.1.8 - - cmake==3.31.2 - - cytoolz==1.0.1 - - deepdiff==8.1.1 - - devtools==0.12.2 - - diskcache==5.6.3 - - factory-boy==3.3.1 - - filelock==3.16.1 - - frozendict==2.4.6 - - gridtools-cpp==2.3.8 - - jinja2==3.1.5 - - lark==1.2.2 - - mako==1.3.8 - - nanobind==2.4.0 - - ninja==1.11.1.3 - - numpy==1.26.4 - - packaging==24.2 - - pybind11==2.13.6 - - setuptools==75.8.0 - - tabulate==0.9.0 - - typing-extensions==4.12.2 - - xxhash==3.0.0 - ##[[[end]]] - - types-tabulate - - types-typed-ast - args: [--no-install-types] - exclude: | - (?x)^( - setup.py | - build/.* | - ci/.* | - docs/.* | - tests/.* - )$ + name: mypy static type checker + entry: uv run mypy --no-install-types src/ + language: system + types_or: [python, pyi] + pass_filenames: false + require_serial: true + stages: [pre-commit] diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000..c8cfe39591 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 15e139a53e..e0ef75d31e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,7 +55,7 @@ Ready to start contributing? We use a [fork and pull request](https://www.atlass 3. Follow instructions in the [README.md](README.md) file to set up an environment for local development. For example: ```bash - $ tox --devenv .venv + $ uv sync --extra all $ source .venv/bin/activate ``` @@ -67,11 +67,11 @@ Ready to start contributing? We use a [fork and pull request](https://www.atlass Now you can make your changes locally. Make sure you follow the project code style documented in [CODING_GUIDELINES.md](CODING_GUIDELINES.md). -5. When you're done making changes, check that your code complies with the project code style and other quality assurance (QA) practices using `pre-commit`. Additionally, make sure that unit and regression tests pass for all supported Python versions by running `tox`: +5. When you're done making changes, check that your code complies with the project code style and other quality assurance (QA) practices using `pre-commit`. Additionally, make sure that unit and regression tests pass for all supported Python versions by running `nox`: ```bash $ pre-commit run - $ tox + $ nox ``` Read [Testing](#testing) section below for further details. @@ -143,21 +143,21 @@ pytest -v -l -s tests/ Check `pytest` documentation (`pytest --help`) for all the options to select and execute tests. -We recommended you to use `tox` for most development-related tasks, like running the complete test suite in different environments. `tox` runs the package installation script in properly isolated environments to run tests (or other tasks) in a reproducible way. A simple way to start with tox could be: +We recommended you to use `nox` for running the test suite in different environments. `nox` runs the package installation script in properly isolated environments to run tests in a reproducible way. A simple way to start with `nox` would be: ```bash # List all the available task environments -tox list +nox list # Run a specific task environment -tox run -e cartesian-py38-internal-cpu +nox -e cartesian-py38-internal-cpu ``` -Check `tox` documentation (`tox --help`) for the complete reference. +Check `nox` documentation (`nox --help`) for the complete reference. +Additionally, `nox` is configured to generate HTML test coverage reports in `tests/_reports/coverage_html/` at the end. --> ## Pull Requests (PRs) and Merge Guidelines @@ -175,27 +175,29 @@ Before submitting a pull request, check that it meets the following criteria: As mentioned above, we use several tools to help us write high-quality code. New tools could be added in the future, especially if they do not add a large overhead to our workflow and they bring extra benefits to keep our codebase in shape. The most important ones which we currently rely on are: -- [ruff][ruff] for style enforcement and code linting. +- [nox][nox] for testing and task automation with different environments. - [pre-commit][pre-commit] for automating the execution of QA tools. - [pytest][pytest] for writing readable tests, extended with: - [Coverage.py][coverage] and [pytest-cov][pytest-cov] for test coverage reports. - [pytest-xdist][pytest-xdist] for running tests in parallel. -- [tox][tox] for testing and task automation with different environments. +- [ruff][ruff] for style enforcement and code linting. - [sphinx][sphinx] for generating documentation, extended with: - [sphinx-autodoc][sphinx-autodoc] and [sphinx-napoleon][sphinx-napoleon] for extracting API documentation from docstrings. - [jupytext][jupytext] for writing new user documentation with code examples. +- [uv][uv] for managing dependencies and environments. [conventional-commits]: https://www.conventionalcommits.org/en/v1.0.0/#summary [coverage]: https://coverage.readthedocs.io/ -[ruff]: https://astral.sh/ruff [jupytext]: https://jupytext.readthedocs.io/ +[nox]: https://nox.thea.codes/en/stable/ [pre-commit]: https://pre-commit.com/ [pytest]: https://docs.pytest.org/ [pytest-cov]: https://pypi.org/project/pytest-cov/ [pytest-xdist]: https://pytest-xdist.readthedocs.io/en/latest/ +[ruff]: https://astral.sh/ruff [sphinx]: https://www.sphinx-doc.org [sphinx-autodoc]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html [sphinx-napoleon]: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/index.html -[tox]: https://tox.wiki/en/latest/ +[uv]: https://docs.astral.sh/uv/ diff --git a/LICENSE.txt b/LICENSE similarity index 100% rename from LICENSE.txt rename to LICENSE diff --git a/README.md b/README.md index 07e0e1cdee..f778c4f54b 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ ![test-eve](https://github.com/GridTools/gt4py/actions/workflows/test-eve.yml/badge.svg?branch=main) ![qa](https://github.com/GridTools/gt4py/actions/workflows/code-quality.yml/badge.svg?branch=main) +[![uv](https://img.shields.io/badge/-uv-261230.svg?logo=uv)](https://github.com/astral-sh/uv) +[![Nox](https://img.shields.io/badge/%F0%9F%A6%8A-Nox-D85E00.svg)](https://github.com/wntrblm/nox) + # GT4Py: GridTools for Python GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. GT4Py is part of the GridTools framework, a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. @@ -36,18 +39,18 @@ The following backends are supported: ## 🚜 Installation -GT4Py can be installed as a regular Python package using `pip` (or any other PEP-517 frontend). As usual, we strongly recommended to create a new virtual environment to work on this project. +GT4Py can be installed as a regular Python package using [uv](https://docs.astral.sh/uv/), [pip](https://pip.pypa.io/en/stable/) or any other PEP-517 compatible frontend. We strongly recommended to use`uv` to create and manage virtual environments for your own projects. The performance backends also require the [Boost](https://www.boost.org) library, a dependency of [GridTools C++](https://github.com/GridTools/gridtools), which needs to be installed by the user. ## ⚙ Configuration -If GridTools or Boost are not found in the compiler's standard include path, or a custom version is desired, then a couple configuration environment variables will allow the compiler to use them: +To explicitly set the [GridTools-C++](https://gridtools.github.io/gridtools) or [Boost](https://www.boost.org) versions used by the code generation backends, the following environment variables can be used: - `GT_INCLUDE_PATH`: Path to the GridTools installation. - `BOOST_ROOT`: Path to a boost installation. -Other commonly used environment variables are: +Other useful available environment variables are: - `CUDA_ARCH`: Set the compute capability of the NVIDIA GPU if it is not detected automatically by `cupy`. - `CXX`: Set the C++ compiler. @@ -56,67 +59,68 @@ Other commonly used environment variables are: More options and details are available in [`config.py`](https://github.com/GridTools/gt4py/blob/main/src/gt4py/cartesian/config.py). -## 📖 Documentation +## 🛠 Development Instructions -GT4Py uses Sphinx documentation. To build the documentation install the dependencies in `requirements-dev.txt` +Follow the installation instructions below to initialize a development virtual environment containing an _editable_ installation of the GT4Py package. Make sure you read the [CONTRIBUTING.md](CONTRIBUTING.md) and [CODING_GUIDELINES.md](CODING_GUIDELINES.md) documents before you start working on the project. -```bash -pip install -r ./gt4py/requirements-dev.txt -``` +### Development Environment Installation using `uv` -and then build the docs with +GT4Py uses the [`uv`](https://docs.astral.sh/uv/) project manager for the development workflow. `uv` is a versatile tool that consolidates functionality usually distributed across different applications into subcommands. -```bash -cd gt4py/docs/user/cartesian -make html # run 'make help' for a list of targets -``` +- The `uv pip` subcommand provides a _fast_ Python package manager, emulating [`pip`](https://pip.pypa.io/en/stable/). +- The `uv export | lock | sync` subcommands manage dependency versions in a manner similar to the [`pip-tools`](https://pip-tools.readthedocs.io/en/stable/) command suite. +- The `uv init | add | remove | build | publish | ...` subcommands facilitate project development workflows, akin to [`hatch`](https://hatch.pypa.io/latest/). +- The `uv tool` subcommand serves as a runner for Python applications in isolation, similar to [`pipx`](https://pipx.pypa.io/stable/). +- The `uv python` subcommands manage different Python installations and versions, much like [`pyenv`](https://github.com/pyenv/pyenv). -## 🛠 Development Instructions +`uv` can be installed in various ways (see its [installation instructions](https://docs.astral.sh/uv/getting-started/installation/)), with the recommended method being the standalone installer: -Follow the installation instructions below to initialize a development virtual environment containing an _editable_ installation of the GT4Py package. Make sure you read the [CONTRIBUTING.md](CONTRIBUTING.md) and [CODING_GUIDELINES.md](CODING_GUIDELINES.md) documents before you start working on the project. - -### Recommended Installation using `tox` +```bash +$ curl -LsSf https://astral.sh/uv/install.sh | sh +``` -If [tox](https://tox.wiki/en/latest/) is already installed in your system (`tox` is available in PyPI and many other package managers), the easiest way to create a virtual environment ready for development is: +Once `uv` is installed in your system, it is enough to clone this repository and let `uv` handling the installation of the development environment. ```bash # Clone the repository git clone https://github.com/gridtools/gt4py.git cd gt4py -# Create the development environment in any location (usually `.venv`) -# selecting one of the following templates: -# dev-py310 -> base environment -# dev-py310-atlas -> base environment + atlas4py bindings -tox devenv -e dev-py310 .venv +# Let uv create the development environment at `.venv`. +# The `--extra all` option tells uv to install all the optional +# dependencies of gt4py, and thus it is not strictly necessary. +# Note that if no dependency groups are provided as an option, +# uv uses `--group dev` by default so the development dependencies +# are installed. +uv sync --extra all -# Finally, activate the environment +# Finally, activate the virtual environment and start writing code! source .venv/bin/activate ``` -### Manual Installation +The newly created _venv_ is a standard Python virtual environment preconfigured with all necessary runtime and development dependencies. Additionally, the `gt4py` package is installed in editable mode, allowing for seamless development and testing. To install new packages in this environment, use the `uv pip` subcommand which emulates the `pip` interface and is generally much faster than the original `pip` tool (which is also available within the venv although its use is discouraged). -Alternatively, a development environment can be created from scratch installing the frozen dependencies packages : +The `pyproject.toml` file contains both the definition of the `gt4py` Python distribution package and the settings of the development tools used in this project, most notably `uv`, `ruff`, and `mypy`. It also contains _dependency groups_ (see [PEP 735](https://peps.python.org/pep-0735/) for further reference) with the development requirements listed in different groups (`build`, `docs`, `lint`, `test`, `typing`, ...) and collected together in the general `dev` group, which gets installed by default by `uv` as mentioned above. -```bash -# Clone the repository -git clone https://github.com/gridtools/gt4py.git -cd gt4py +### Development Tasks (`dev-tasks.py`) -# Create a (Python 3.10) virtual environment (usually at `.venv`) -python3.10 -m venv .venv +Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration accross different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`). -# Activate the virtual environment and update basic packages -source .venv/bin/activate -pip install --upgrade wheel setuptools pip +## 📖 Documentation + +GT4Py uses the Sphinx tool for the documentation. To build browseable HTML documentation, install the required tools provided in the `docs` dependency group: -# Install the required development tools -pip install -r requirements-dev.txt -# Install GT4Py project in editable mode -pip install -e . +```bash +uv install --group docs --extra all # or --group dev +``` + +(Note that most likely these tools are already installed in your development environment, since the `docs` group is included in the `dev` group, which installed by default by `uv sync` if no dependency groups are specified.) + +Once the requirements are already installed, then build the docs using: -# Optionally, install atlas4py bindings directly from the repo -# pip install git+https://github.com/GridTools/atlas4py#egg=atlas4py +```bash +cd gt4py/docs/user/cartesian +make html # run 'make help' for a list of targets ``` ## ⚖️ License diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index ea7c4722c7..1ad9aefa03 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -57,4 +57,4 @@ ENV PATH="/root/.pyenv/shims:${PATH}" ARG CUPY_PACKAGE=cupy-cuda12x ARG CUPY_VERSION=13.3.0 -RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==${CUPY_VERSION} +RUN pip install --upgrade pip setuptools wheel uv nox ${CUPY_PACKAGE}==${CUPY_VERSION} diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index b5ea07b787..05955913ba 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -4,7 +4,6 @@ include: .py311: &py311 PYVERSION_PREFIX: py311 PYVERSION: 3.11.9 - .py310: &py310 PYVERSION_PREFIX: py310 PYVERSION: 3.10.9 @@ -111,13 +110,14 @@ build_py310_image_aarch64: image: $CSCS_REGISTRY_PATH/public/$ARCH/gt4py/gt4py-ci:$CI_COMMIT_SHA-$PYVERSION script: - cd /gt4py.src - - python -c "import cupy" - - tox run -e $SUBPACKAGE-$PYVERSION_PREFIX$VARIANT$SUBVARIANT + - NOX_SESSION_ARGS="${VARIANT:+($VARIANT}${SUBVARIANT:+, $SUBVARIANT}${DETAIL:+, $DETAIL}${VARIANT:+)}" + - nox -e "test_$SUBPACKAGE-${PYVERSION:0:4}$NOX_SESSION_ARGS" variables: CRAY_CUDA_MPS: 1 SLURM_JOB_NUM_NODES: 1 SLURM_TIMELIMIT: 15 NUM_PROCESSES: auto + PYENV_VERSION: $PYVERSION VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 # .test_helper_x86_64: # extends: [.container-runner-daint-gpu, .test_helper] @@ -134,13 +134,16 @@ build_py310_image_aarch64: extends: [.container-runner-daint-gh200, .test_helper] parallel: matrix: - - SUBPACKAGE: [cartesian, storage] - VARIANT: [-internal, -dace] - SUBVARIANT: [-cuda12x, -cpu] + - SUBPACKAGE: [cartesian] + VARIANT: ['internal', 'dace'] + SUBVARIANT: ['cuda12', 'cpu'] - SUBPACKAGE: eve - SUBPACKAGE: next - VARIANT: [-nomesh, -atlas] - SUBVARIANT: [-cuda12x, -cpu] + VARIANT: ['internal', 'dace'] + SUBVARIANT: ['cuda12', 'cpu'] + DETAIL: ['nomesh', 'atlas'] + - SUBPACKAGE: [storage] + VARIANT: ['cuda12', 'cpu'] variables: # Grace-Hopper gpu architecture is not enabled by default in CUDA build CUDAARCHS: "90" diff --git a/constraints.txt b/constraints.txt deleted file mode 100644 index 8b3e5e697f..0000000000 --- a/constraints.txt +++ /dev/null @@ -1,178 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# "tox run -e requirements-base" -# -aenum==3.1.15 # via dace -alabaster==1.0.0 # via sphinx -annotated-types==0.7.0 # via pydantic -asttokens==2.4.1 # via devtools, stack-data -astunparse==1.6.3 # via dace -attrs==24.3.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.16.0 # via sphinx -black==24.10.0 # via gt4py (pyproject.toml) -boltons==24.1.0 # via gt4py (pyproject.toml) -bracex==2.5.post1 # via wcmatch -build==1.2.2.post1 # via pip-tools -bump-my-version==0.29.0 # via -r requirements-dev.in -cached-property==2.0.1 # via gt4py (pyproject.toml) -cachetools==5.5.0 # via tox -certifi==2024.12.14 # via requests -cfgv==3.4.0 # via pre-commit -chardet==5.2.0 # via tox -charset-normalizer==3.4.1 # via requests -clang-format==19.1.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.8 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.2 # via gt4py (pyproject.toml) -cogapp==3.4.1 # via -r requirements-dev.in -colorama==0.4.6 # via tox -comm==0.2.2 # via ipykernel -contourpy==1.3.1 # via matplotlib -coverage==7.6.10 # via -r requirements-dev.in, pytest-cov -cycler==0.12.1 # via matplotlib -cytoolz==1.0.1 # via gt4py (pyproject.toml) -dace==1.0.0 # via gt4py (pyproject.toml) -darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.11 # via ipykernel -decorator==5.1.1 # via ipython -deepdiff==8.1.1 # via gt4py (pyproject.toml) -devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.9 # via dace -diskcache==5.6.3 # via gt4py (pyproject.toml) -distlib==0.3.9 # via virtualenv -docutils==0.21.2 # via sphinx, sphinx-rtd-theme -exceptiongroup==1.2.2 # via hypothesis, ipython, pytest -execnet==2.1.1 # via pytest-cache, pytest-xdist -executing==2.1.0 # via devtools, stack-data -factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==33.3.0 # via factory-boy -fastjsonschema==2.21.1 # via nbformat -filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.3 # via matplotlib -fparser==0.2.0 # via dace -frozendict==2.4.6 # via gt4py (pyproject.toml) -gitdb==4.0.12 # via gitpython -gitpython==3.1.44 # via tach -gridtools-cpp==2.3.8 # via gt4py (pyproject.toml) -hypothesis==6.123.11 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.5 # via pre-commit -idna==3.10 # via requests -imagesize==1.4.1 # via sphinx -inflection==0.5.1 # via pytest-factoryboy -iniconfig==2.0.0 # via pytest -ipykernel==6.29.5 # via nbmake -ipython==8.31.0 # via ipykernel -jax==0.4.38 # via gt4py (pyproject.toml) -jaxlib==0.4.38 # via jax -jedi==0.19.2 # via ipython -jinja2==3.1.5 # via gt4py (pyproject.toml), sphinx -jsonschema==4.23.0 # via nbformat -jsonschema-specifications==2024.10.1 # via jsonschema -jupyter-client==8.6.3 # via ipykernel, nbclient -jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbclient, nbformat -jupytext==1.16.6 # via -r requirements-dev.in -kiwisolver==1.4.8 # via matplotlib -lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.8 # via gt4py (pyproject.toml) -markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich -markupsafe==3.0.2 # via jinja2, mako -matplotlib==3.10.0 # via -r requirements-dev.in -matplotlib-inline==0.1.7 # via ipykernel, ipython -mdit-py-plugins==0.4.2 # via jupytext -mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.5.1 # via jax, jaxlib -mpmath==1.3.0 # via sympy -mypy==1.14.1 # via -r requirements-dev.in -mypy-extensions==1.0.0 # via black, mypy -nanobind==2.4.0 # via gt4py (pyproject.toml) -nbclient==0.10.2 # via nbmake -nbformat==5.10.4 # via jupytext, nbclient, nbmake -nbmake==1.5.5 # via -r requirements-dev.in -nest-asyncio==1.6.0 # via ipykernel -networkx==3.4.2 # via dace, tach -ninja==1.11.1.3 # via gt4py (pyproject.toml) -nodeenv==1.9.1 # via pre-commit -numpy==1.26.4 # via contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, scipy -opt-einsum==3.4.0 # via jax -orderly-set==5.2.3 # via deepdiff -packaging==24.2 # via black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox -parso==0.8.4 # via jedi -pathspec==0.12.1 # via black -pexpect==4.9.0 # via ipython -pillow==11.1.0 # via matplotlib -pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.24.0 # via -r requirements-dev.in -platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv -pluggy==1.5.0 # via pytest, tox -ply==3.11 # via dace -pre-commit==4.0.1 # via -r requirements-dev.in -prompt-toolkit==3.0.48 # via ipython, questionary, tach -psutil==6.1.1 # via -r requirements-dev.in, ipykernel, pytest-xdist -ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.3 # via stack-data -pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.10.4 # via bump-my-version, pydantic-settings -pydantic-core==2.27.2 # via pydantic -pydantic-settings==2.7.1 # via bump-my-version -pydot==3.0.4 # via tach -pygments==2.19.1 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.2.1 # via matplotlib, pydot -pyproject-api==1.8.0 # via tox -pyproject-hooks==1.2.0 # via build, pip-tools -pytest==8.3.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist -pytest-cache==1.0 # via -r requirements-dev.in -pytest-cov==6.0.0 # via -r requirements-dev.in -pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in -pytest-factoryboy==2.7.0 # via -r requirements-dev.in -pytest-instafail==0.5.0 # via -r requirements-dev.in -pytest-xdist==3.6.1 # via -r requirements-dev.in -python-dateutil==2.9.0.post0 # via faker, jupyter-client, matplotlib -python-dotenv==1.0.1 # via pydantic-settings -pyyaml==6.0.2 # via dace, jupytext, pre-commit, tach -pyzmq==26.2.0 # via ipykernel, jupyter-client -questionary==2.1.0 # via bump-my-version -referencing==0.35.1 # via jsonschema, jsonschema-specifications -requests==2.32.3 # via sphinx -rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.5 # via bump-my-version -rpds-py==0.22.3 # via jsonschema, referencing -ruff==0.8.6 # via -r requirements-dev.in -scipy==1.15.0 # via gt4py (pyproject.toml), jax, jaxlib -setuptools-scm==8.1.0 # via fparser -six==1.17.0 # via asttokens, astunparse, python-dateutil -smmap==5.0.2 # via gitdb -snowballstemmer==2.2.0 # via sphinx -sortedcontainers==2.4.0 # via hypothesis -sphinx==8.1.3 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.2 # via -r requirements-dev.in -sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.1.0 # via sphinx -sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==2.0.0 # via sphinx -stack-data==0.6.3 # via ipython -sympy==1.13.3 # via dace -tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.19.5 # via -r requirements-dev.in -tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, sphinx, tach, tox -tomli-w==1.1.0 # via tach -tomlkit==0.13.2 # via bump-my-version -toolz==1.0.0 # via cytoolz -tornado==6.4.2 # via ipykernel, jupyter-client -tox==4.23.2 # via -r requirements-dev.in -traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20241207 # via -r requirements-dev.in -typing-extensions==4.12.2 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, tox -urllib3==2.3.0 # via requests -virtualenv==20.28.1 # via pre-commit, tox -wcmatch==10.0 # via bump-my-version -wcwidth==0.2.13 # via prompt-toolkit -wheel==0.45.1 # via astunparse, pip-tools -xxhash==3.0.0 # via gt4py (pyproject.toml) - -# The following packages are considered to be unsafe in a requirements file: -pip==24.3.1 # via pip-tools, pipdeptree -setuptools==75.8.0 # via gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/dev-tasks.py b/dev-tasks.py new file mode 100755 index 0000000000..437d107807 --- /dev/null +++ b/dev-tasks.py @@ -0,0 +1,97 @@ +#! /usr/bin/env -S uv run -q +# +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +# +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "typer>=0.12.3", +# ] +# [tool.uv] +# exclude-newer = "2025-01-31T00:00:00Z" +# /// + + +"""Script for running recurrent development tasks.""" + +from __future__ import annotations + +import pathlib +import subprocess +from typing import Final + +import typer + +ROOT_DIR: Final = pathlib.Path(__file__).parent + + +# -- Helpers -- +def gather_versions() -> dict[str, str]: + with subprocess.Popen( + [*"uv export --frozen --no-hashes --project".split(), ROOT_DIR], stdout=subprocess.PIPE + ) as proc: + return dict( + line.split("==") + for line in proc.stdout.read().decode().splitlines() + if not any(line.startswith(c) for c in ["-", "#"]) + ) + + +# -- CLI -- +app = typer.Typer(no_args_is_help=True) + + +@app.command() +def sync_precommit() -> None: + """Sync versions of tools used in pre-commit hooks with the project versions.""" + versions = gather_versions() + # Update ruff version in pre-commit config + subprocess.run( + f"""uvx -q --from 'yamlpath' yaml-set --mustexist --change='repos[.repo%https://github.com/astral-sh/ruff-pre-commit].rev' --value='v{versions["ruff"]}' .pre-commit-config.yaml""", + cwd=ROOT_DIR, + shell=True, + check=True, + ) + + # Update tach version in pre-commit config + subprocess.run( + f"""uvx -q --from 'yamlpath' yaml-set --mustexist --change='repos[.repo%https://github.com/gauge-sh/tach-pre-commit].rev' --value='v{versions["tach"]}' .pre-commit-config.yaml""", + cwd=ROOT_DIR, + shell=True, + check=True, + ) + + # Format yaml files + subprocess.run( + f"uv run --project {ROOT_DIR} pre-commit run pretty-format-yaml --all-files", shell=True + ) + + +@app.command() +def update_precommit() -> None: + """Update and sync pre-commit hooks with the latest compatible versions.""" + subprocess.run(f"uv run --project {ROOT_DIR} pre-commit autoupdate", shell=True) + sync_precommit() + + +@app.command() +def update_versions() -> None: + """Update all project dependencies to their latest compatible versions.""" + subprocess.run("uv lock --upgrade", cwd=ROOT_DIR, shell=True, check=True) + + +@app.command() +def update_all() -> None: + """Update all project dependencies and pre-commit hooks.""" + update_versions() + update_precommit() + + +if __name__ == "__main__": + app() diff --git a/docs/development/tools/ci-infrastructure.md b/docs/development/tools/ci-infrastructure.md index 242bea50bd..e76cb7d608 100644 --- a/docs/development/tools/ci-infrastructure.md +++ b/docs/development/tools/ci-infrastructure.md @@ -1,6 +1,6 @@ # CI infrastructure -Any test job that runs on CI is encoded in automation tools like **tox** and **pre-commit** and can be run locally instead. +Any test job that runs on CI is encoded in automation tools like **nox** and **pre-commit** and can be run locally instead. ## GitHub Workflows diff --git a/docs/development/tools/requirements.md b/docs/development/tools/requirements.md deleted file mode 100644 index 010f317493..0000000000 --- a/docs/development/tools/requirements.md +++ /dev/null @@ -1,27 +0,0 @@ -# Requirements - -The specification of required third-party packages is scattered and partially duplicated across several configuration files used by several tools. Keeping all package requirements in sync manually is challenging and error-prone. Therefore, in this project we use [pip-tools](https://pip-tools.readthedocs.io/en/latest/) and the [cog](https://nedbatchelder.com/code/cog/) file generation tool to avoid inconsistencies. - -The following files in this repository contain information about required third-party packages: - -- `pyproject.toml`: GT4Py [package configuration](https://peps.python.org/pep-0621/) used by the build backend (`setuptools`). Install dependencies are specified in the _project.dependencies_ and _project.optional-dependencies_ tables. -- `requirements-dev.in`: [requirements file](https://pip.pypa.io/en/stable/reference/requirements-file-format/) used by **pip**. It contains a list of packages required only for the development of GT4Py. -- `requirements-dev.txt`: requirements file used by **pip**. It contains a completely frozen list of all packages required for installing and developing GT4Py. It is used by **pip** and **tox** to initialize the standard development and testing environments. It is automatically generated automatically from `requirements-dev.in` by **pip-compile**, when running the **tox** environment to update requirements. -- `constraints.txt`: [constraints file](https://pip.pypa.io/en/stable/user_guide/#constraints-files) used by **pip** and **tox** to initialize a subset of the standard development environment making sure that if other packages are installed, transitive dependencies are taken from the frozen package list. It is generated automatically from `requirements-dev.in` using **pip-compile**. -- `min-requirements-test.txt`: requirements file used by **pip**. It contains the minimum list of requirements to run GT4Py tests with the oldest compatible versions of all dependencies. It is generated automatically from `pyproject.toml` using **cog**. -- `min-extra-requirements-test.txt`: requirements file used by **pip**. It contains the minimum list of requirements to run GT4Py tests with the oldest compatible versions of all dependencies, additionally including all GT4Py extras. It is generated automatically from `pyproject.toml` using **cog**. -- `.pre-commit-config.yaml`: **pre-commit** configuration with settings for many linting and formatting tools. Part of its content is generated automatically from `pyproject.toml` using **cog**. - -The expected workflow to update GT4Py requirements is as follows: - -1. For changes in the GT4Py package dependencies, update the relevant table in `pyproject.toml`. When adding new tables to the _project.optional-dependencies_ section, make sure to add the new table as a dependency of the `all-` extra tables when possible. - -2. For changes in the development tools, update the `requirements-dev.in` file. Note that required project packages already appearing in `pyproject.toml` should not be duplicated here. - -3. Run the **tox** _requirements-base_ environment to update all files automatically with **pip-compile** and **cog**. Note that **pip-compile** will most likely update the versions of some unrelated tools if new versions are available in PyPI. - - ```bash - tox r -e requirements-base - ``` - -4. Check that the **mypy** mirror used by **pre-commit** (https://github.com/pre-commit/mirrors-mypy) in `.pre-commit-config.yaml` supports the same version as in `constraints.txt`, and manually update the `rev` version number. diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt deleted file mode 100644 index a4924cc09c..0000000000 --- a/min-extra-requirements-test.txt +++ /dev/null @@ -1,110 +0,0 @@ -# -# Generated automatically by cog from pyproject.toml and requirements-dev.in -# Run: -# tox r -e requirements-common -# - -##[[[cog -## import copy, sys -## from packaging import requirements as reqs, specifiers as specs -## if sys.version_info >= (3, 11): -## import tomllib -## else: -## import tomli as tomllib -## -## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: -## for s in r.specifier: -## if (ss := str(s)).startswith(">"): -## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" -## min_spec = specs.SpecifierSet(f"=={ss[2:]}") -## break -## min_r = copy.deepcopy(r) -## min_r.specifier = min_spec -## return min_r -## -## project = tomllib.loads(open("pyproject.toml").read()) -## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] -## assert len(all_cpu_extra) == 1 and all_cpu_extra[0].startswith("gt4py[") -## opt_req_versions = { -## reqs.Requirement(r).name: reqs.Requirement(r) -## for e in reqs.Requirement(all_cpu_extra[0]).extras -## for r in project["project"]["optional-dependencies"][e] -## } -## requirements = [ -## reqs.Requirement(rr) -## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) -## if (rr := (r[: r.find("#")] if "#" in r else r)) -## ] -## processed = set() -## result = [] -## for r in requirements: -## assert r.name not in processed -## processed.add(r.name) -## if not r.specifier: -## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" -## r = opt_req_versions[r.name] -## result.append(str(make_min_req(r))) -## for r_name, r in opt_req_versions.items(): -## if r_name not in processed: -## result.append(str(make_min_req(r))) -## print("\n".join(sorted(result))) -##]]] -attrs==21.3 -black==22.3 -boltons==20.1 -bump-my-version==0.12.0 -cached-property==1.5.1 -clang-format==9.0 -click==8.0.0 -cmake==3.22 -cogapp==3.3 -coverage[toml]==5.0 -cytoolz==0.12.1 -dace==1.0.0 -darglint==1.6 -deepdiff==5.6.0 -devtools==0.6 -diskcache==5.6.3 -factory-boy==3.3.0 -filelock==3.16.1 -frozendict==2.3 -gridtools-cpp==2.3.8 -hypothesis==6.0.0 -jax[cpu]==0.4.18 -jinja2==3.0.0 -jupytext==1.14 -lark==1.1.2 -mako==1.1 -matplotlib==3.3 -mypy==1.0 -nanobind==1.4.0 -nbmake==1.4.6 -ninja==1.10 -numpy==1.23.3 -packaging==20.0 -pip-tools==6.10 -pipdeptree==2.3 -pre-commit==2.17 -psutil==5.0 -pybind11==2.10.1 -pygments==2.7.3 -pytest-cache==1.0 -pytest-cov==2.8 -pytest-custom-exit-code==0.3.0 -pytest-factoryboy==2.0.3 -pytest-instafail==0.5.0 -pytest-xdist[psutil]==2.4 -pytest==7.0 -ruff==0.2.0 -scipy==1.9.2 -setuptools==65.5.0 -sphinx==4.4 -sphinx_rtd_theme==1.0 -tabulate==0.8.10 -tach==0.10.7 -tomli==2.0.1; python_version < "3.11" -tox==3.2.0 -types-tabulate==0.8.10 -typing-extensions==4.10.0 -xxhash==1.4.4 -##[[[end]]] diff --git a/min-requirements-test.txt b/min-requirements-test.txt deleted file mode 100644 index 4b24385410..0000000000 --- a/min-requirements-test.txt +++ /dev/null @@ -1,104 +0,0 @@ -# -# Generated automatically by cog from pyproject.toml and requirements-dev.in -# Run: -# tox r -e requirements-common -# - -##[[[cog -## import copy, sys -## from packaging import requirements as reqs, specifiers as specs -## if sys.version_info >= (3, 11): -## import tomllib -## else: -## import tomli as tomllib -## -## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: -## for s in r.specifier: -## if (ss := str(s)).startswith(">"): -## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" -## min_spec = specs.SpecifierSet(f"=={ss[2:]}") -## break -## min_r = copy.deepcopy(r) -## min_r.specifier = min_spec -## return min_r -## -## project = tomllib.loads(open("pyproject.toml").read()) -## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] -## assert len(all_cpu_extra) == 1 and all_cpu_extra[0].startswith("gt4py[") -## opt_req_versions = { -## reqs.Requirement(r).name: reqs.Requirement(r) -## for e in reqs.Requirement(all_cpu_extra[0]).extras -## for r in project["project"]["optional-dependencies"][e] -## } -## requirements = [ -## reqs.Requirement(rr) -## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) -## if (rr := (r[: r.find("#")] if "#" in r else r)) -## ] -## processed = set() -## result = [] -## for r in requirements: -## assert r.name not in processed -## processed.add(r.name) -## if not r.specifier: -## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" -## r = opt_req_versions[r.name] -## result.append(str(make_min_req(r))) -## print("\n".join(sorted(result))) -##]]] -attrs==21.3 -black==22.3 -boltons==20.1 -bump-my-version==0.12.0 -cached-property==1.5.1 -clang-format==9.0 -click==8.0.0 -cmake==3.22 -cogapp==3.3 -coverage[toml]==5.0 -cytoolz==0.12.1 -darglint==1.6 -deepdiff==5.6.0 -devtools==0.6 -diskcache==5.6.3 -factory-boy==3.3.0 -filelock==3.16.1 -frozendict==2.3 -gridtools-cpp==2.3.8 -hypothesis==6.0.0 -jinja2==3.0.0 -jupytext==1.14 -lark==1.1.2 -mako==1.1 -matplotlib==3.3 -mypy==1.0 -nanobind==1.4.0 -nbmake==1.4.6 -ninja==1.10 -numpy==1.23.3 -packaging==20.0 -pip-tools==6.10 -pipdeptree==2.3 -pre-commit==2.17 -psutil==5.0 -pybind11==2.10.1 -pygments==2.7.3 -pytest-cache==1.0 -pytest-cov==2.8 -pytest-custom-exit-code==0.3.0 -pytest-factoryboy==2.0.3 -pytest-instafail==0.5.0 -pytest-xdist[psutil]==2.4 -pytest==7.0 -ruff==0.2.0 -setuptools==65.5.0 -sphinx==4.4 -sphinx_rtd_theme==1.0 -tabulate==0.8.10 -tach==0.10.7 -tomli==2.0.1; python_version < "3.11" -tox==3.2.0 -types-tabulate==0.8.10 -typing-extensions==4.10.0 -xxhash==1.4.4 -##[[[end]]] diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000000..0b150c0db7 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,251 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +import pathlib +import types +from collections.abc import Sequence +from typing import Final, Literal, TypeAlias + +import nox + +#: This should just be `pytest.ExitCode.NO_TESTS_COLLECTED` but `pytest` +#: is not guaranteed to be available in the venv where `nox` is running. +NO_TESTS_COLLECTED_EXIT_CODE: Final = 5 + +# -- nox configuration -- +nox.options.default_venv_backend = "uv" +nox.options.sessions = [ + "test_cartesian-3.10(internal, cpu)", + "test_cartesian-3.10(dace, cpu)", + "test_cartesian-3.11(internal, cpu)", + "test_cartesian-3.11(dace, cpu)", + "test_eve-3.10", + "test_eve-3.11", + "test_next-3.10(internal, cpu, nomesh)", + "test_next-3.10(dace, cpu, nomesh)", + "test_next-3.11(internal, cpu, nomesh)", + "test_next-3.11(dace, cpu, nomesh)", + "test_storage-3.10(cpu)", + "test_storage-3.11(cpu)", +] + +# -- Parameter sets -- +DeviceOption: TypeAlias = Literal["cpu", "cuda11", "cuda12", "rocm4_3", "rocm5_0"] +DeviceNoxParam: Final = types.SimpleNamespace( + **{device: nox.param(device, id=device, tags=[device]) for device in DeviceOption.__args__} +) +DeviceTestSettings: Final[dict[str, dict[str, Sequence]]] = { + "cpu": {"extras": [], "markers": ["not requires_gpu"]}, + **{ + device: {"extras": [device], "markers": ["requires_gpu"]} + for device in ["cuda11", "cuda12", "rocm4_3", "rocm5_0"] + }, +} + +CodeGenOption: TypeAlias = Literal["internal", "dace"] +CodeGenNoxParam: Final = types.SimpleNamespace( + **{ + codegen: nox.param(codegen, id=codegen, tags=[codegen]) + for codegen in CodeGenOption.__args__ + } +) +CodeGenTestSettings: Final[dict[str, dict[str, Sequence]]] = { + "internal": {"extras": [], "markers": ["not requires_dace"]}, + "dace": {"extras": ["dace"], "markers": ["requires_dace"]}, +} + + +# -- nox sessions -- +@nox.session(python=["3.10", "3.11"], tags=["cartesian"]) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +@nox.parametrize("codegen", [CodeGenNoxParam.internal, CodeGenNoxParam.dace]) +def test_cartesian( + session: nox.Session, + codegen: CodeGenOption, + device: DeviceOption, +) -> None: + """Run selected 'gt4py.cartesian' tests.""" + + codegen_settings = CodeGenTestSettings[codegen] + device_settings = DeviceTestSettings[device] + + _install_session_venv( + session, + extras=["performance", "testing", *codegen_settings["extras"], *device_settings["extras"]], + groups=["test"], + ) + + num_processes = session.env.get("NUM_PROCESSES", "auto") + markers = " and ".join(codegen_settings["markers"] + device_settings["markers"]) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "cartesian_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules --doctest-ignore-import-errors -sv".split(), + str(pathlib.Path("src") / "gt4py" / "cartesian"), + ) + + +@nox.session(python=["3.10", "3.11"]) +def test_examples(session: nox.Session) -> None: + """Run and test documentation workflows.""" + + _install_session_venv(session, extras=["testing"], groups=["docs", "test"]) + + session.run(*"jupytext docs/user/next/QuickstartGuide.md --to .ipynb".split()) + session.run(*"jupytext docs/user/next/advanced/*.md --to .ipynb".split()) + + num_processes = session.env.get("NUM_PROCESSES", "auto") + for notebook, extra_args in [ + ("docs/user/next/workshop/slides", None), + ("docs/user/next/workshop/exercises", ["-k", "solutions"]), + ("docs/user/next/QuickstartGuide.ipynb", None), + ("docs/user/next/advanced", None), + ("examples", (None)), + ]: + session.run( + *f"pytest --nbmake {notebook} -sv -n {num_processes}".split(), + *(extra_args or []), + ) + + +@nox.session(python=["3.10", "3.11"], tags=["cartesian", "next", "cpu"]) +def test_eve(session: nox.Session) -> None: + """Run 'gt4py.eve' tests.""" + + _install_session_venv(session, groups=["test"]) + + num_processes = session.env.get("NUM_PROCESSES", "auto") + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + str(pathlib.Path("tests") / "eve_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules -sv".split(), + str(pathlib.Path("src") / "gt4py" / "eve"), + ) + + +@nox.session(python=["3.10", "3.11"], tags=["next"]) +@nox.parametrize( + "meshlib", + [ + nox.param("nomesh", id="nomesh", tags=["nomesh"]), + nox.param("atlas", id="atlas", tags=["atlas"]), + ], +) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +@nox.parametrize("codegen", [CodeGenNoxParam.internal, CodeGenNoxParam.dace]) +def test_next( + session: nox.Session, + codegen: CodeGenOption, + device: DeviceOption, + meshlib: Literal["nomesh", "atlas"], +) -> None: + """Run selected 'gt4py.next' tests.""" + + codegen_settings = CodeGenTestSettings[codegen] + device_settings = DeviceTestSettings[device] + groups: list[str] = ["test"] + mesh_markers: list[str] = [] + + match meshlib: + case "nomesh": + mesh_markers.append("not requires_atlas") + case "atlas": + mesh_markers.append("requires_atlas") + groups.append("frameworks") + + _install_session_venv( + session, + extras=["performance", "testing", *codegen_settings["extras"], *device_settings["extras"]], + groups=groups, + ) + + num_processes = session.env.get("NUM_PROCESSES", "auto") + markers = " and ".join(codegen_settings["markers"] + device_settings["markers"] + mesh_markers) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "next_tests"), + *session.posargs, + success_codes=[0, NO_TESTS_COLLECTED_EXIT_CODE], + ) + session.run( + *"pytest --doctest-modules --doctest-ignore-import-errors -sv".split(), + str(pathlib.Path("src") / "gt4py" / "next"), + success_codes=[0, NO_TESTS_COLLECTED_EXIT_CODE], + ) + + +@nox.session(python=["3.10", "3.11"], tags=["cartesian", "next"]) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +def test_storage( + session: nox.Session, + device: DeviceOption, +) -> None: + """Run selected 'gt4py.storage' tests.""" + + device_settings = DeviceTestSettings[device] + + _install_session_venv( + session, extras=["performance", "testing", *device_settings["extras"]], groups=["test"] + ) + + num_processes = session.env.get("NUM_PROCESSES", "auto") + markers = " and ".join(device_settings["markers"]) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "storage_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules -sv".split(), + str(pathlib.Path("src") / "gt4py" / "storage"), + success_codes=[0, NO_TESTS_COLLECTED_EXIT_CODE], + ) + + +# -- utils -- +def _install_session_venv( + session: nox.Session, + *args: str | Sequence[str], + extras: Sequence[str] = (), + groups: Sequence[str] = (), +) -> None: + """Install session packages using uv.""" + session.run_install( + "uv", + "sync", + *("--python", session.python), + "--no-dev", + *(f"--extra={e}" for e in extras), + *(f"--group={g}" for g in groups), + env={key: value for key, value in os.environ.items()} + | {"UV_PROJECT_ENVIRONMENT": session.virtualenv.location}, + ) + for item in args: + session.run_install( + "uv", + "pip", + "install", + *((item,) if isinstance(item, str) else item), + env={"UV_PROJECT_ENVIRONMENT": session.virtualenv.location}, + ) diff --git a/pyproject.toml b/pyproject.toml index 979dfbbd02..91c2ba0323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,68 @@ +# -- Build system requirements (PEP 518) -- + [build-system] build-backend = 'setuptools.build_meta' -requires = ['setuptools>=65.5.0', 'wheel>=0.33.6', 'cython>=0.29.13'] +requires = ['setuptools>=70.0.0', 'wheel>=0.33.6', 'cython>=3.0.0'] + +# -- Dependency groups -- +[dependency-groups] +build = [ + 'bump-my-version>=0.16.0', + 'cython>=3.0.0', + 'pip>=22.1.1', + 'setuptools>=70.0.0', + 'wheel>=0.33.6' +] +dev = [ + {include-group = 'build'}, + {include-group = 'docs'}, + {include-group = 'frameworks'}, + {include-group = 'lint'}, + {include-group = 'test'}, + {include-group = 'typing'} +] +docs = [ + 'esbonio>=0.16.0', + 'jupytext>=1.14', + 'matplotlib>=3.3', + 'myst-parser>=4.0.0', + 'pygments>=2.7.3', + 'sphinx>=7.3.7', + 'sphinx-rtd-theme>=3.0.1', + 'sphinx-toolbox>=3.8.1' +] +frameworks = [ + # 3rd party frameworks with some interoperability with gt4py + 'atlas4py>=0.35' +] +lint = [ + 'pre-commit>=4.0.1', + 'ruff>=0.8.0', + 'tach>=0.16.0' +] +test = [ + 'coverage[toml]>=7.5.0', + 'hypothesis>=6.0.0', + 'nbmake>=1.4.6', + 'nox>=2024.10.9', + 'pytest>=8.0.1', + 'pytest-benchmark>=5.0.0', + 'pytest-cache>=1.0', + 'pytest-cov>=5.0.0', + 'pytest-factoryboy>=2.6.1', + 'pytest-instafail>=0.5.0', + 'pytest-xdist[psutil]>=3.5.0' +] +typing = [ + 'mypy[faster-cache]>=1.13.0', + 'types-tabulate>=0.8.10', + 'types-PyYAML>=6.0.10', + 'types-decorator>=5.1.8', + 'types-docutils>=0.21.0', + 'types-pytz>=2024.2.0' +] -# ---- Project description ---- -# -- Standard options (PEP 621) -- +# -- Standard project description options (PEP 621) -- [project] authors = [{name = 'ETH Zurich', email = 'gridtools@cscs.ch'}] classifiers = [ @@ -44,9 +103,10 @@ dependencies = [ 'numpy>=1.23.3', 'packaging>=20.0', 'pybind11>=2.10.1', - 'setuptools>=65.5.0', + 'setuptools>=70.0.0', 'tabulate>=0.8.10', - 'typing-extensions>=4.10.0', + 'toolz>=0.12.1', + 'typing-extensions>=4.11.0', 'xxhash>=1.4.4,<3.1.0' ] description = 'Python library for generating high-performance implementations of stencil kernels for weather and climate modeling from a domain-specific language (DSL)' @@ -60,27 +120,25 @@ keywords = [ 'portable', 'hpc' ] -license = {file = 'LICENSE.txt'} +license = {text = 'BSD-3 License'} # TODO: waiting for PEP 639 being implemented by setuptools (https://github.com/codecov/codecov-cli/issues/605) name = 'gt4py' readme = 'README.md' -requires-python = '>=3.10' +requires-python = '>=3.10, <3.12' [project.optional-dependencies] -# Bundles -all-cpu = ['gt4py[dace,formatting,jax-cpu,performance,testing]'] -all-cuda11 = ['gt4py[cuda11,dace,formatting,jax-cuda11,performance,testing]'] -all-cuda12 = ['gt4py[cuda12,dace,formatting,jax-cuda12,performance,testing]'] -# Other extras +# bundles +all = ['gt4py[dace,formatting,jax,performance,testing]'] +# device-specific extras cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] +# features dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 formatting = ['clang-format>=9.0'] -gpu = ['cupy>=12.0'] -jax-cpu = ['jax[cpu]>=0.4.18'] -jax-cuda11 = ['jax[cuda11_pip]>=0.4.18'] -jax-cuda12 = ['jax[cuda12_pip]>=0.4.18'] +jax = ['jax>=0.4.26'] +jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]'] performance = ['scipy>=1.9.2'] -rocm-43 = ['cupy-rocm-4-3'] +rocm4_3 = ['cupy-rocm-4-3>=13.3.0'] +rocm5_0 = ['cupy-rocm-5-0>=13.3.0'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] [project.scripts] @@ -89,7 +147,7 @@ gtpyc = 'gt4py.cartesian.cli:gtpyc' [project.urls] Documentation = 'https://gridtools.github.io/gt4py' Homepage = 'https://gridtools.github.io/' -Source = 'https://github.com/GridTools/gt4py' +Repository = 'https://github.com/GridTools/gt4py' # ---- Other tools ---- # -- bump-my-version -- @@ -97,7 +155,7 @@ Source = 'https://github.com/GridTools/gt4py' allow_dirty = false commit = false commit_args = '' -current_version = "1.0.4" +current_version = '1.0.4' ignore_missing_version = false message = 'Bump version: {current_version} → {new_version}' parse = '(?P\d+)\.(?P\d+)(\.(?P\d+))?' @@ -111,7 +169,7 @@ tag_message = 'Bump version: {current_version} → {new_version}' tag_name = 'v{new_version}' [[tool.bumpversion.files]] -filename = "src/gt4py/__about__.py" +filename = 'src/gt4py/__about__.py' # -- coverage -- [tool.coverage] @@ -120,7 +178,7 @@ filename = "src/gt4py/__about__.py" directory = 'tests/_reports/coverage_html' [tool.coverage.paths] -source = ['src/', '.tox/py*/lib/python3.*/site-packages/'] +source = ['src/', '.nox/py*/lib/python3.*/site-packages/'] [tool.coverage.report] # Regexes for lines to exclude from consideration @@ -303,6 +361,9 @@ select = ['E', 'F', 'I', 'B', 'A', 'T10', 'ERA', 'NPY', 'RUF'] typing-modules = ['gt4py.eve.extended_typing'] unfixable = [] +[tool.ruff.lint.flake8-builtins] +builtins-allowed-modules = ['builtins'] + [tool.ruff.lint.isort] combine-as-imports = true # force-wrap-aliases = true @@ -368,3 +429,22 @@ version = {attr = 'gt4py.__about__.__version__'} [tool.setuptools.packages] find = {namespaces = false, where = ['src']} + +# -- uv: packages & workspace -- +[tool.uv] +conflicts = [ + [ + {extra = 'cuda11'}, + {extra = 'jax-cuda12'}, + {extra = 'rocm4_3'}, + {extra = 'rocm5_0'} + ] +] + +[[tool.uv.index]] +explicit = true +name = 'test.pypi' +url = 'https://test.pypi.org/simple/' + +[tool.uv.sources] +atlas4py = {index = "test.pypi"} diff --git a/requirements-dev.in b/requirements-dev.in deleted file mode 100644 index 1697051d25..0000000000 --- a/requirements-dev.in +++ /dev/null @@ -1,36 +0,0 @@ -# -# Constraints should specify the minimum required version (>=). -# -# Packages also required in the extra `gt4py['all-cpu']` configuration -# should be added here without constraints, so they will use the -# constraints defined in `pyproject.toml`. -# -bump-my-version>=0.12.0 -clang-format>=9.0 -cogapp>=3.3 -coverage[toml]>=5.0 -darglint>=1.6 -hypothesis # constraints in gt4py['testing'] -jupytext>=1.14 -mypy>=1.0 -matplotlib>=3.3 -nbmake>=1.4.6 -pipdeptree>=2.3 -pip-tools>=6.10 -pre-commit>=2.17 -psutil>=5.0 -pygments>=2.7.3 -pytest # constraints in gt4py['testing'] -pytest-cache>=1.0 -pytest-cov>=2.8 -pytest-custom-exit-code>=0.3.0 -pytest-factoryboy>=2.0.3 -pytest-xdist[psutil]>=2.4 -pytest-instafail>=0.5.0 -ruff>=0.2.0 -sphinx>=4.4 -sphinx_rtd_theme>=1.0 -tach>=0.10.7 -tomli>=2.0.1;python_version<'3.11' -tox>=3.2.0 -types-tabulate>=0.8.10 diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 463b1bc6ac..0000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,178 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# "tox run -e requirements-base" -# -aenum==3.1.15 # via -c constraints.txt, dace -alabaster==1.0.0 # via -c constraints.txt, sphinx -annotated-types==0.7.0 # via -c constraints.txt, pydantic -asttokens==2.4.1 # via -c constraints.txt, devtools, stack-data -astunparse==1.6.3 # via -c constraints.txt, dace -attrs==24.3.0 # via -c constraints.txt, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.16.0 # via -c constraints.txt, sphinx -black==24.10.0 # via -c constraints.txt, gt4py (pyproject.toml) -boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml) -bracex==2.5.post1 # via -c constraints.txt, wcmatch -build==1.2.2.post1 # via -c constraints.txt, pip-tools -bump-my-version==0.29.0 # via -c constraints.txt, -r requirements-dev.in -cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) -cachetools==5.5.0 # via -c constraints.txt, tox -certifi==2024.12.14 # via -c constraints.txt, requests -cfgv==3.4.0 # via -c constraints.txt, pre-commit -chardet==5.2.0 # via -c constraints.txt, tox -charset-normalizer==3.4.1 # via -c constraints.txt, requests -clang-format==19.1.6 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.8 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.2 # via -c constraints.txt, gt4py (pyproject.toml) -cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in -colorama==0.4.6 # via -c constraints.txt, tox -comm==0.2.2 # via -c constraints.txt, ipykernel -contourpy==1.3.1 # via -c constraints.txt, matplotlib -coverage[toml]==7.6.10 # via -c constraints.txt, -r requirements-dev.in, pytest-cov -cycler==0.12.1 # via -c constraints.txt, matplotlib -cytoolz==1.0.1 # via -c constraints.txt, gt4py (pyproject.toml) -dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.11 # via -c constraints.txt, ipykernel -decorator==5.1.1 # via -c constraints.txt, ipython -deepdiff==8.1.1 # via -c constraints.txt, gt4py (pyproject.toml) -devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) -dill==0.3.9 # via -c constraints.txt, dace -diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml) -distlib==0.3.9 # via -c constraints.txt, virtualenv -docutils==0.21.2 # via -c constraints.txt, sphinx, sphinx-rtd-theme -exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, ipython, pytest -execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist -executing==2.1.0 # via -c constraints.txt, devtools, stack-data -factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==33.3.0 # via -c constraints.txt, factory-boy -fastjsonschema==2.21.1 # via -c constraints.txt, nbformat -filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.3 # via -c constraints.txt, matplotlib -fparser==0.2.0 # via -c constraints.txt, dace -frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) -gitdb==4.0.12 # via -c constraints.txt, gitpython -gitpython==3.1.44 # via -c constraints.txt, tach -gridtools-cpp==2.3.8 # via -c constraints.txt, gt4py (pyproject.toml) -hypothesis==6.123.11 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.5 # via -c constraints.txt, pre-commit -idna==3.10 # via -c constraints.txt, requests -imagesize==1.4.1 # via -c constraints.txt, sphinx -inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy -iniconfig==2.0.0 # via -c constraints.txt, pytest -ipykernel==6.29.5 # via -c constraints.txt, nbmake -ipython==8.31.0 # via -c constraints.txt, ipykernel -jax[cpu]==0.4.38 # via -c constraints.txt, gt4py (pyproject.toml) -jaxlib==0.4.38 # via -c constraints.txt, jax -jedi==0.19.2 # via -c constraints.txt, ipython -jinja2==3.1.5 # via -c constraints.txt, gt4py (pyproject.toml), sphinx -jsonschema==4.23.0 # via -c constraints.txt, nbformat -jsonschema-specifications==2024.10.1 # via -c constraints.txt, jsonschema -jupyter-client==8.6.3 # via -c constraints.txt, ipykernel, nbclient -jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, nbclient, nbformat -jupytext==1.16.6 # via -c constraints.txt, -r requirements-dev.in -kiwisolver==1.4.8 # via -c constraints.txt, matplotlib -lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.8 # via -c constraints.txt, gt4py (pyproject.toml) -markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich -markupsafe==3.0.2 # via -c constraints.txt, jinja2, mako -matplotlib==3.10.0 # via -c constraints.txt, -r requirements-dev.in -matplotlib-inline==0.1.7 # via -c constraints.txt, ipykernel, ipython -mdit-py-plugins==0.4.2 # via -c constraints.txt, jupytext -mdurl==0.1.2 # via -c constraints.txt, markdown-it-py -ml-dtypes==0.5.1 # via -c constraints.txt, jax, jaxlib -mpmath==1.3.0 # via -c constraints.txt, sympy -mypy==1.14.1 # via -c constraints.txt, -r requirements-dev.in -mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) -nbclient==0.10.2 # via -c constraints.txt, nbmake -nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake -nbmake==1.5.5 # via -c constraints.txt, -r requirements-dev.in -nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel -networkx==3.4.2 # via -c constraints.txt, dace, tach -ninja==1.11.1.3 # via -c constraints.txt, gt4py (pyproject.toml) -nodeenv==1.9.1 # via -c constraints.txt, pre-commit -numpy==1.26.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, scipy -opt-einsum==3.4.0 # via -c constraints.txt, jax -orderly-set==5.2.3 # via -c constraints.txt, deepdiff -packaging==24.2 # via -c constraints.txt, black, build, dace, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox -parso==0.8.4 # via -c constraints.txt, jedi -pathspec==0.12.1 # via -c constraints.txt, black -pexpect==4.9.0 # via -c constraints.txt, ipython -pillow==11.1.0 # via -c constraints.txt, matplotlib -pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.24.0 # via -c constraints.txt, -r requirements-dev.in -platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv -pluggy==1.5.0 # via -c constraints.txt, pytest, tox -ply==3.11 # via -c constraints.txt, dace -pre-commit==4.0.1 # via -c constraints.txt, -r requirements-dev.in -prompt-toolkit==3.0.48 # via -c constraints.txt, ipython, questionary, tach -psutil==6.1.1 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist -ptyprocess==0.7.0 # via -c constraints.txt, pexpect -pure-eval==0.2.3 # via -c constraints.txt, stack-data -pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.10.4 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.27.2 # via -c constraints.txt, pydantic -pydantic-settings==2.7.1 # via -c constraints.txt, bump-my-version -pydot==3.0.4 # via -c constraints.txt, tach -pygments==2.19.1 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.2.1 # via -c constraints.txt, matplotlib, pydot -pyproject-api==1.8.0 # via -c constraints.txt, tox -pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools -pytest==8.3.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist -pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in -pytest-cov==6.0.0 # via -c constraints.txt, -r requirements-dev.in -pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in -pytest-factoryboy==2.7.0 # via -c constraints.txt, -r requirements-dev.in -pytest-instafail==0.5.0 # via -c constraints.txt, -r requirements-dev.in -pytest-xdist[psutil]==3.6.1 # via -c constraints.txt, -r requirements-dev.in -python-dateutil==2.9.0.post0 # via -c constraints.txt, faker, jupyter-client, matplotlib -python-dotenv==1.0.1 # via -c constraints.txt, pydantic-settings -pyyaml==6.0.2 # via -c constraints.txt, dace, jupytext, pre-commit, tach -pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client -questionary==2.1.0 # via -c constraints.txt, bump-my-version -referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications -requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.5 # via -c constraints.txt, bump-my-version -rpds-py==0.22.3 # via -c constraints.txt, jsonschema, referencing -ruff==0.8.6 # via -c constraints.txt, -r requirements-dev.in -scipy==1.15.0 # via -c constraints.txt, jax, jaxlib -setuptools-scm==8.1.0 # via -c constraints.txt, fparser -six==1.17.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil -smmap==5.0.2 # via -c constraints.txt, gitdb -snowballstemmer==2.2.0 # via -c constraints.txt, sphinx -sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis -sphinx==8.1.3 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==3.0.2 # via -c constraints.txt, -r requirements-dev.in -sphinxcontrib-applehelp==2.0.0 # via -c constraints.txt, sphinx -sphinxcontrib-devhelp==2.0.0 # via -c constraints.txt, sphinx -sphinxcontrib-htmlhelp==2.1.0 # via -c constraints.txt, sphinx -sphinxcontrib-jquery==4.1 # via -c constraints.txt, sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 # via -c constraints.txt, sphinx -sphinxcontrib-qthelp==2.0.0 # via -c constraints.txt, sphinx -sphinxcontrib-serializinghtml==2.0.0 # via -c constraints.txt, sphinx -stack-data==0.6.3 # via -c constraints.txt, ipython -sympy==1.13.3 # via -c constraints.txt, dace -tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.19.5 # via -c constraints.txt, -r requirements-dev.in -tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, sphinx, tach, tox -tomli-w==1.1.0 # via -c constraints.txt, tach -tomlkit==0.13.2 # via -c constraints.txt, bump-my-version -toolz==1.0.0 # via -c constraints.txt, cytoolz -tornado==6.4.2 # via -c constraints.txt, ipykernel, jupyter-client -tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in -traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20241207 # via -c constraints.txt, -r requirements-dev.in -typing-extensions==4.12.2 # via -c constraints.txt, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, tox -urllib3==2.3.0 # via -c constraints.txt, requests -virtualenv==20.28.1 # via -c constraints.txt, pre-commit, tox -wcmatch==10.0 # via -c constraints.txt, bump-my-version -wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -wheel==0.45.1 # via -c constraints.txt, astunparse, pip-tools -xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) - -# The following packages are considered to be unsafe in a requirements file: -pip==24.3.1 # via -c constraints.txt, pip-tools, pipdeptree -setuptools==75.8.0 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/tach.toml b/tach.toml index 7861ed1fe6..d23b5fb14d 100644 --- a/tach.toml +++ b/tach.toml @@ -3,7 +3,9 @@ source_roots = [ "src", ] exact = true -forbid_circular_dependencies = true +# forbid_circular_dependencies = true +# TODO(egparedes): try to solve the circular dependencies between +# gt4py.cartesian and gt4py.storage [[modules]] path = "gt4py._core" diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py index 9fafc27c85..16a1860d9d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py @@ -29,7 +29,7 @@ ) -pytestmark = pytest.mark.usefixtures("dace_env") +pytestmark = [pytest.mark.requires_dace, pytest.mark.usefixtures("dace_env")] @pytest.fixture(scope="module") diff --git a/tests/conftest.py b/tests/conftest.py index 285ccda2b0..1bf73651a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,5 +8,55 @@ """Global configuration of pytest for collecting and running tests.""" +import collections.abc +import functools +import sys +import types +from typing import Final + +import pytest + + # Ignore hidden folders and disabled tests collect_ignore_glob = [".*", "_disabled*"] + +# Custom module attribute to store package-level marks +_PKG_MARKS_ATTR_NAME: Final = "package_pytestmarks" + + +@functools.cache +def _get_pkg_marks(module_name: str) -> list[pytest.Mark | str]: + """Collect markers in the `package_pytestmarks` module attribute (and recursively from its parents).""" + module = sys.modules[module_name] + pkg_markers = getattr(module, _PKG_MARKS_ATTR_NAME, []) + assert isinstance( + pkg_markers, collections.abc.Sequence + ), f"'{_PKG_MARKS_ATTR_NAME}' content must be a sequence of marks" + + if (parent := module_name.rsplit(".", 1)[0]) != module_name: + pkg_markers += _get_pkg_marks(parent) + + return pkg_markers + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: list[pytest.Item] +) -> None: + """Pytest hook to modify the collected test items. + + See: https://docs.pytest.org/en/stable/reference/reference.html#pytest.hookspec.pytest_collection_modifyitems + """ + for item in items: + # Visit the chain of parents of the current test item in reverse order, + # until we get to the module object where the test function (or class) + # has been defined. At that point, process the custom package-level marks + # attribute if present, and move to the next collected item in the list. + for node in item.listchain()[-2::-1]: + if not (obj := getattr(node, "obj", None)): + break + if not isinstance(obj, types.ModuleType): + continue + + module_name = obj.__name__ + for marker in _get_pkg_marks(module_name): + item.add_marker(marker) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 6ffbc667bb..522250cafc 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -11,6 +11,7 @@ import dataclasses import enum import importlib +from typing import Final import pytest @@ -53,11 +54,17 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @dataclasses.dataclass(frozen=True) class EmbeddedDummyBackend: + name: str allocator: next_allocators.FieldBufferAllocatorProtocol + executor: Final = None -numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) -cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator()) +numpy_execution = EmbeddedDummyBackend( + "EmbeddedNumPy", next_allocators.StandardCPUFieldBufferAllocator() +) +cupy_execution = EmbeddedDummyBackend( + "EmbeddedCuPy", next_allocators.StandardGPUFieldBufferAllocator() +) class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): diff --git a/tests/next_tests/integration_tests/feature_tests/dace/__init__.py b/tests/next_tests/integration_tests/feature_tests/dace/__init__.py index abf4c3e24c..7a9cb1ece5 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/__init__.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/__init__.py @@ -6,3 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest + +#: Attribute defining package-level marks used by a custom pytest hook. +package_pytestmarks = [pytest.mark.requires_dace] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 22af788845..8fe0634302 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -30,12 +30,7 @@ ) -try: - import dace -except ImportError: - dace: Optional[ModuleType] = None # type:ignore[no-redef] - -pytestmark = pytest.mark.requires_dace +dace = pytest.importorskip("dace") def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 2a7a3710a9..4edaf9f85f 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -23,18 +23,9 @@ ) -try: - import dace +dace = pytest.importorskip("dace") - from gt4py.next.program_processors.runners import dace as dace_backends -except ImportError: - from types import ModuleType - from typing import Optional - - from gt4py.next import backend as next_backend - - dace: Optional[ModuleType] = None - dace_backends: Optional[ModuleType] = None +from gt4py.next.program_processors.runners import dace as dace_backends @pytest.fixture( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py index 7358ab3d8f..2356e9c781 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py @@ -6,15 +6,15 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import typing import pytest from gt4py import next as gtx -from gt4py.next import common from gt4py.next.iterator.transforms import extractors -from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( IDim, JDim, @@ -22,53 +22,14 @@ ) -if typing.TYPE_CHECKING: - from types import ModuleType - from typing import Optional - -try: - import dace - - from gt4py.next.program_processors.runners.dace import run_dace_cpu -except ImportError: - from gt4py.next import backend as next_backend - - dace: Optional[ModuleType] = None - run_dace_cpu: Optional[next_backend.Backend] = None - - -@pytest.fixture(params=[pytest.param(run_dace_cpu, marks=pytest.mark.requires_dace), gtx.gtfn_cpu]) -def gtir_dace_backend(request): - yield request.param - - -@pytest.fixture -def cartesian(request, gtir_dace_backend): - if gtir_dace_backend is None: - yield None - - yield cases.Case( - backend=gtir_dace_backend, - offset_provider={ - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - }, - default_sizes={IDim: 10, JDim: 10, KDim: 10}, - grid_type=common.GridType.CARTESIAN, - allocator=gtir_dace_backend.allocator, - ) - - -@pytest.mark.skipif(dace is None, reason="DaCe not found") -def test_input_names_extractor_cartesian(cartesian): - @gtx.field_operator(backend=cartesian.backend) +def test_input_names_extractor_cartesian(): + @gtx.field_operator def testee_op( a: gtx.Field[[IDim, JDim, KDim], gtx.int], ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: return a - @gtx.program(backend=cartesian.backend) + @gtx.program def testee( a: gtx.Field[[IDim, JDim, KDim], gtx.int], b: gtx.Field[[IDim, JDim, KDim], gtx.int], @@ -81,15 +42,14 @@ def testee( assert input_field_names == {"a", "b"} -@pytest.mark.skipif(dace is None, reason="DaCe not found") -def test_output_names_extractor(cartesian): - @gtx.field_operator(backend=cartesian.backend) +def test_output_names_extractor(): + @gtx.field_operator def testee_op( a: gtx.Field[[IDim, JDim, KDim], gtx.int], ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: return a - @gtx.program(backend=cartesian.backend) + @gtx.program def testee( a: gtx.Field[[IDim, JDim, KDim], gtx.int], b: gtx.Field[[IDim, JDim, KDim], gtx.int], diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index 6c6ca7e4bc..da354be7ea 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -13,6 +13,7 @@ pytest.importorskip("atlas4py") +import gt4py._core.definitions as core_defs from gt4py import next as gtx from gt4py.next import allocators, neighbor_sum from gt4py.next.iterator import atlas_utils @@ -62,12 +63,15 @@ def pnabla( return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol) +@pytest.mark.requires_atlas def test_ffront_compute_zavgS(exec_alloc_descriptor): - _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator + # TODO(havogt): fix nabla setup to work with GPU + if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: + pytest.skip("This test is only supported on CPU devices yet") - setup = nabla_setup(allocator=allocator) + setup = nabla_setup(allocator=exec_alloc_descriptor.allocator) - zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) + zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=exec_alloc_descriptor.allocator) compute_zavgS.with_backend( None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor @@ -82,13 +86,16 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): assert_close(388241977.58389181, np.max(zavgS.asnumpy())) +@pytest.mark.requires_atlas def test_ffront_nabla(exec_alloc_descriptor): - _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator + # TODO(havogt): fix nabla setup to work with GPU + if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: + pytest.skip("This test is only supported on CPU devices yet") - setup = nabla_setup(allocator=allocator) + setup = nabla_setup(allocator=exec_alloc_descriptor.allocator) - pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) + pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator) + pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator) pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( setup.input_field, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py index 9fa07e46e9..1cdf0f0591 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py @@ -9,4 +9,5 @@ import pytest -pytestmark = pytest.mark.requires_dace +#: Attribute defining package-level marks used by a custom pytest hook. +package_pytestmarks = [pytest.mark.requires_dace] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 62d88d9f0a..ca4a1e0f1f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -27,8 +27,6 @@ mesh_descriptor, ) -from . import pytestmark - dace = pytest.importorskip("dace") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index bfde179e33..7431ad2b4a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -34,8 +34,6 @@ skip_value_mesh, ) -from . import pytestmark - dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py index 6c3b1060b6..a576665ee3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py @@ -9,4 +9,5 @@ import pytest -pytestmark = [pytest.mark.requires_dace, pytest.mark.usefixtures("set_dace_settings")] +#: Attribute defining package-level marks used by a custom pytest hook. +package_pytestmarks = [pytest.mark.usefixtures("common_dace_config")] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py index 0eb0bf39c2..c3455c37cc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py @@ -12,7 +12,7 @@ @pytest.fixture(autouse=True) -def set_dace_settings() -> Generator[None, None, None]: +def common_dace_config() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. The function will modify the following settings: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index 9241bae4bf..ae3624ce13 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -6,9 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import pytest import numpy as np +dace = pytest.importorskip("dace") + from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) @@ -16,11 +20,6 @@ from . import util -# dace = pytest.importorskip("dace") -from dace.sdfg import nodes as dace_nodes -import dace - - def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg")) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index cdc66d4ffd..350fa807a1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -18,7 +18,7 @@ gpu_utils as gtx_dace_fieldview_gpu_utils, ) -from . import pytestmark + from . import util diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index a08cf12a5a..4d7a8156d7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -20,7 +20,7 @@ transformations as gtx_transformations, ) -from . import pytestmark + from . import util diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index 516a70b579..cd4ad77787 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -21,7 +21,7 @@ transformations as gtx_transformations, ) -from . import pytestmark + from . import util diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py index 762040e20d..d82127f6f3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -6,9 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import pytest import numpy as np +dace = pytest.importorskip("dace") + from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) @@ -16,10 +20,6 @@ from . import util -dace = pytest.importorskip("dace") -from dace.sdfg import nodes as dace_nodes - - def _perform_reorder_test( sdfg: dace.SDFG, leading_dim: list[str], diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index fa7c7255e3..3bd0ed2dc3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -16,7 +16,7 @@ transformations as gtx_transformations, ) -from . import pytestmark + from . import util diff --git a/tox.ini b/tox.ini deleted file mode 100644 index e7bfd4a3e4..0000000000 --- a/tox.ini +++ /dev/null @@ -1,190 +0,0 @@ -[tox] -requires = - tox>=4.2 - virtualenv>20.2 -envlist = - cartesian-py{310}-{internal,dace}-{cpu} - eve-py{310} - next-py{310}-{nomesh,atlas}-{cpu} - storage-py{310}-{internal,dace}-{cpu} - # docs -labels = - test-cartesian-cpu = cartesian-internal-py310-cpu, cartesian-py311-internal-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu - test-eve-cpu = eve-py310, eve-py311 - test-next-cpu = next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu - test-storage-cpu = storage-py310-internal-cpu, storage-py311-internal-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu - test-cpu = cartesian-py310-internal-cpu, cartesian-py311-internal-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu, \ - eve-py310, eve-py311, \ - next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu, \ - storage-py310-internal-cpu, storage-py311-internal-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu - -[testenv] -deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} -constrain_package_deps = true -use_frozen_constraints = true -extras = - testing - formatting - dace: dace - cuda: cuda - cuda11x: cuda11x - cuda12x: cuda12x -package = wheel -wheel_build_env = .pkg -pass_env = CUDAARCHS, NUM_PROCESSES, GT4PY_* -set_env = - PYTEST_ADDOPTS = --color=auto --instafail - PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning,ignore:Field View Program:UserWarning} - -# -- Primary tests -- -[testenv:cartesian-py{310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] -description = Run 'gt4py.cartesian' tests -pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE -allowlist_externals = - make - gcc - g++ - ldd - rm -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - internal: not requires_dace \ - dace: requires_dace \ - cpu: and not requires_gpu \ - {cuda,cuda11x,cuda12x}: and requires_gpu \ - " {posargs} tests{/}cartesian_tests - python -m pytest --doctest-modules --doctest-ignore-import-errors src{/}gt4py{/}cartesian -# commands_pre = -# rm -Rf tests/_reports/coverage* -# commands_post = -# coverage json --rcfile=setup.cfg -# coverage html --rcfile=setup.cfg --show-contexts - -[testenv:eve-py{310,311}] -description = Run 'gt4py.eve' tests -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests - python -m pytest --doctest-modules src{/}gt4py{/}eve - -[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] -description = Run 'gt4py.next' tests -pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH -deps = - -r {tox_root}{/}requirements-dev.txt - atlas: atlas4py -set_env = - {[testenv]set_env} - PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} -commands = - python -m pytest --suppress-no-test-exit-code --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - nomesh: not requires_atlas \ - atlas: requires_atlas \ - cpu: and not requires_gpu \ - {cuda,cuda11x,cuda12x}: and requires_gpu \ - " {posargs} tests{/}next_tests - pytest --doctest-modules src{/}gt4py{/}next - -[testenv:storage-py{310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] -description = Run 'gt4py.storage' tests -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - cpu: not requires_gpu \ - {cuda,cuda11x,cuda12x}: requires_gpu \ - " {posargs} tests{/}storage_tests - # pytest doctest-modules {posargs} src{/}gt4py{/}storage - -# -- Secondary tests -- -[testenv:notebooks-py{310,311}] -description = Run notebooks -commands_pre = - jupytext docs/user/next/QuickstartGuide.md --to .ipynb - jupytext docs/user/next/advanced/*.md --to .ipynb -commands = - python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/advanced -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} - -# -- Other artefacts -- -[testenv:dev-py{310,311}{-atlas,}] -description = Initialize development environment for gt4py -deps = - -r {tox_root}{/}requirements-dev.txt - atlas: atlas4py -package = editable-legacy # => use_develop = True -set_env = - {[testenv]set_env} - PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} - -# [testenv:diagrams] -# install_command = echo {packages} -# skip_install = true -# allowlist_externals = -# /bin/bash -# make -# gcc -# g++ -# ldd -# rm -# plantuml -# git -# echo -# changedir = docs/development/ADRs -# commands = -# plantuml ./*.md -tsvg -o _static -# git add _static -# commands_post = - -[testenv:requirements-{base,py310,py311}] -description = - base: Update pinned development requirements - py310: Update requirements for testing a specific python version - py311: Update requirements for testing a specific python version -base_python = - base: py310 - py310: py310 - py311: py311 -deps = - cogapp>=3.3 - packaging>=20.0 - pip-tools>=6.10 -package = skip -set_env = - CUSTOM_COMPILE_COMMAND = "tox run -e requirements-base" -allowlist_externals = - mv -commands = - -mv constraints.txt constraints.txt.old - -mv requirements-dev.txt requirements-dev.old - # Run cog to update requirements files from pyproject - cog -r -P min-requirements-test.txt min-extra-requirements-test.txt - # Generate constraints file removing extras - # (extras are not supported by pip in constraints files) - pip-compile -r --resolver=backtracking \ - --annotation-style line \ - --build-isolation \ - --strip-extras \ - --allow-unsafe \ - --extra dace \ - --extra formatting \ - --extra jax-cpu \ - --extra performance \ - --extra testing \ - -o constraints.txt \ - pyproject.toml requirements-dev.in - # Generate actual requirements file - # (compiling from scratch again to print actual package sources) - pip-compile --resolver=backtracking \ - --annotation-style line \ - --build-isolation \ - --allow-unsafe \ - --extra dace \ - --extra formatting \ - --extra jax-cpu \ - --extra testing \ - -c constraints.txt \ - -o requirements-dev.txt \ - pyproject.toml requirements-dev.in - # Run cog to update .pre-commit-config.yaml with new versions - base: cog -r -P .pre-commit-config.yaml diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..1d050717af --- /dev/null +++ b/uv.lock @@ -0,0 +1,3345 @@ +version = 1 +requires-python = ">=3.10, <3.12" +resolution-markers = [ + "python_full_version < '3.11'", + "python_full_version >= '3.11'", +] +conflicts = [[ + { package = "gt4py", extra = "cuda11" }, + { package = "gt4py", extra = "jax-cuda12" }, + { package = "gt4py", extra = "rocm4-3" }, + { package = "gt4py", extra = "rocm5-0" }, +]] + +[[package]] +name = "aenum" +version = "3.1.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/f8/33e75863394f42e429bb553e05fda7c59763f0fd6848de847a25b3fbccf6/aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559", size = 134730 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/fa/ca0c66b388624ba9dbbf35aab3a9f326bfdf5e56a7237fe8f1b600da6864/aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288", size = 137633 }, +] + +[[package]] +name = "alabaster" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/d9c74d0daf3f742840fd818d69cfae176fa332022fd44e3469487d5a9420/alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e", size = 24210 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/b3/6b4067be973ae96ba0d615946e314c5ae35f9f993eca561b356540bb0c2b/alabaster-1.0.0-py3-none-any.whl", hash = "sha256:fc6786402dc3fcb2de3cabd5fe455a2db534b371124f1f21de8731783dec828b", size = 13929 }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "apeye" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apeye-core" }, + { name = "domdf-python-tools" }, + { name = "platformdirs" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/6b/cc65e31843d7bfda8313a9dc0c77a21e8580b782adca53c7cb3e511fe023/apeye-1.4.1.tar.gz", hash = "sha256:14ea542fad689e3bfdbda2189a354a4908e90aee4bf84c15ab75d68453d76a36", size = 99219 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/7b/2d63664777b3e831ac1b1d8df5bbf0b7c8bee48e57115896080890527b1b/apeye-1.4.1-py3-none-any.whl", hash = "sha256:44e58a9104ec189bf42e76b3a7fe91e2b2879d96d48e9a77e5e32ff699c9204e", size = 107989 }, +] + +[[package]] +name = "apeye-core" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "domdf-python-tools" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e5/4c/4f108cfd06923bd897bf992a6ecb6fb122646ee7af94d7f9a64abd071d4c/apeye_core-1.1.5.tar.gz", hash = "sha256:5de72ed3d00cc9b20fea55e54b7ab8f5ef8500eb33a5368bc162a5585e238a55", size = 96511 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/9f/fa9971d2a0c6fef64c87ba362a493a4f230eff4ea8dfb9f4c7cbdf71892e/apeye_core-1.1.5-py3-none-any.whl", hash = "sha256:dc27a93f8c9e246b3b238c5ea51edf6115ab2618ef029b9f2d9a190ec8228fbf", size = 99286 }, +] + +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, +] + +[[package]] +name = "argcomplete" +version = "3.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/be/6c23d80cb966fb8f83fb1ebfb988351ae6b0554d0c3a613ee4531c026597/argcomplete-3.5.3.tar.gz", hash = "sha256:c12bf50eded8aebb298c7b7da7a5ff3ee24dffd9f5281867dfe1424b58c55392", size = 72999 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/08/2a4db06ec3d203124c967fc89295e85a202e5cbbcdc08fd6a64b65217d1e/argcomplete-3.5.3-py3-none-any.whl", hash = "sha256:2ab2c4a215c59fd6caaff41a869480a23e8f6a5f910b266c1808037f4e375b61", size = 43569 }, +] + +[[package]] +name = "asttokens" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/45/1d/f03bcb60c4a3212e15f99a56085d93093a497718adf828d050b9d675da81/asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0", size = 62284 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/86/4736ac618d82a20d87d2f92ae19441ebc7ac9e7a581d7e58bbe79233b24a/asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24", size = 27764 }, +] + +[[package]] +name = "astunparse" +version = "1.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "wheel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872", size = 18290 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8", size = 12732 }, +] + +[[package]] +name = "atlas4py" +version = "0.35.1.dev15" +source = { registry = "https://test.pypi.org/simple/" } +sdist = { url = "https://test-files.pythonhosted.org/packages/59/e4/48ede747be846f80b30d6303d732f96ca44ee9858504140db5222d2345bb/atlas4py-0.35.1.dev15.tar.gz", hash = "sha256:3c4274261d99a03ffd14a23dfb9ee9265ce79d8db7887751f4fbf1a315091664", size = 15079 } +wheels = [ + { url = "https://test-files.pythonhosted.org/packages/7a/47/0d1f8f7ba596a60bef920638724dfcc76f4edbfdb6bb79932b7e12ec45fc/atlas4py-0.35.1.dev15-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:244ae7f016d28ad04f8e9071de34192c1f8a58fd075477e327c4528cad8daacf", size = 6040572 }, + { url = "https://test-files.pythonhosted.org/packages/5d/f5/2b5645ec670b4088816ca7089fae06c6d72f0a4c301ef186ec8ac8e715fd/atlas4py-0.35.1.dev15-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:12b8f63df17e22d0ddc8310ced42e4db903e73b167c4f261f180cc2c011888ca", size = 5752419 }, + { url = "https://test-files.pythonhosted.org/packages/41/c3/03f3f061d28865f307c7916a0b82b8d37efeddb6cd4085aa687718341aee/atlas4py-0.35.1.dev15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cadd6e5de2e0771d129b6242cbe0bd9268bed16d37ee3cc65b97a7de19a67933", size = 5251334 }, + { url = "https://test-files.pythonhosted.org/packages/22/fe/32d912deb54d7e9eaecc652b813b86925616be358222e069ded6e3bea8c6/atlas4py-0.35.1.dev15-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:0cacea024adb384aacb5da5a5a233e23cf8563e4f357e9687eeac0d9c4c9a4d8", size = 6041915 }, + { url = "https://test-files.pythonhosted.org/packages/ef/94/e85cc3588d836e58974f7be1b362ce321f5989ae8c355a75faee5b09f131/atlas4py-0.35.1.dev15-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:d5f3147e8ad52b890ffc4d92b51d6fd2b34bb39b89e09d6d4d5d7fec9f48aa0f", size = 5753565 }, + { url = "https://test-files.pythonhosted.org/packages/36/5e/71c7c054ae756f7cd5a984a44edad85ca20f4a0364ccc10052363314a9f2/atlas4py-0.35.1.dev15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c13f4a4a88dbe0eb056920d57eafa3e0f1e9fc117bd3c8773cfebca945ed8d76", size = 5253094 }, +] + +[[package]] +name = "attrs" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/49/7c/fdf464bcc51d23881d110abd74b512a42b3d5d376a55a831b44c603ae17f/attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e", size = 810562 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/30/d4986a882011f9df997a55e6becd864812ccfcd821d64aac8570ee39f719/attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a", size = 63152 }, +] + +[[package]] +name = "autodocsumm" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/03/96/92afe8a7912b327c01f0a8b6408c9556ee13b1aba5b98d587ac7327ff32d/autodocsumm-0.2.14.tar.gz", hash = "sha256:2839a9d4facc3c4eccd306c08695540911042b46eeafcdc3203e6d0bab40bc77", size = 46357 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/bc/3f66af9beb683728e06ca08797e4e9d3e44f432f339718cae3ba856a9cad/autodocsumm-0.2.14-py3-none-any.whl", hash = "sha256:3bad8717fc5190802c60392a7ab04b9f3c97aa9efa8b3780b3d81d615bfe5dc0", size = 14640 }, +] + +[[package]] +name = "babel" +version = "2.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/74/f1bc80f23eeba13393b7222b11d95ca3af2c1e28edca18af487137eefed9/babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316", size = 9348104 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/20/bc79bc575ba2e2a7f70e8a1155618bb1301eaa5132a8271373a6903f73f8/babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b", size = 9587599 }, +] + +[[package]] +name = "beautifulsoup4" +version = "4.12.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/ca/824b1195773ce6166d388573fc106ce56d4a805bd7427b624e063596ec58/beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051", size = 581181 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed", size = 147925 }, +] + +[[package]] +name = "black" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419 }, + { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080 }, + { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886 }, + { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404 }, + { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372 }, + { url = "https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865 }, + { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699 }, + { url = "https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028 }, + { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646 }, +] + +[[package]] +name = "boltons" +version = "24.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/84/76/dfc34232b3e88634025563f52a430be0838182647c063f99569086922554/boltons-24.1.0.tar.gz", hash = "sha256:4a49b7d57ee055b83a458c8682a2a6f199d263a8aa517098bda9bab813554b87", size = 240916 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/96/e44606e60a0c005ac5f2a641960a93ca8f449ebdce7479f9bc4f10bead6d/boltons-24.1.0-py3-none-any.whl", hash = "sha256:a1776d47fdc387fb730fba1fe245f405ee184ee0be2fb447dd289773a84aed3b", size = 192196 }, +] + +[[package]] +name = "bracex" +version = "2.5.post1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/6c/57418c4404cd22fe6275b8301ca2b46a8cdaa8157938017a9ae0b3edf363/bracex-2.5.post1.tar.gz", hash = "sha256:12c50952415bfa773d2d9ccb8e79651b8cdb1f31a42f6091b804f6ba2b4a66b6", size = 26641 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/02/8db98cdc1a58e0abd6716d5e63244658e6e63513c65f469f34b6f1053fd0/bracex-2.5.post1-py3-none-any.whl", hash = "sha256:13e5732fec27828d6af308628285ad358047cec36801598368cb28bc631dbaf6", size = 11558 }, +] + +[[package]] +name = "bump-my-version" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "questionary" }, + { name = "rich" }, + { name = "rich-click" }, + { name = "tomlkit" }, + { name = "wcmatch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/4f/57eda33958c5820b462c4c262bc18dc374dca6312bbb63f95606172200cb/bump_my_version-0.30.0.tar.gz", hash = "sha256:d53e784c73abc4bb5759e296f510bc71878e1df078eb525542ec9291b5ceb195", size = 1062228 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/9b/965ad61f85cbde14694516b02dcd38ec0c5cf7132fe33a30fddb4d8b0803/bump_my_version-0.30.0-py3-none-any.whl", hash = "sha256:b0d683a1cb97fbc2f46adf8eb39ff1f0bdd72866c3583fe01f9837d6f031e5e3", size = 55257 }, +] + +[[package]] +name = "cachecontrol" +version = "0.14.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "msgpack" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/a4/3390ac4dfa1773f661c8780368018230e8207ec4fd3800d2c0c3adee4456/cachecontrol-0.14.2.tar.gz", hash = "sha256:7d47d19f866409b98ff6025b6a0fca8e4c791fb31abbd95f622093894ce903a2", size = 28832 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/63/baffb44ca6876e7b5fc8fe17b24a7c07bf479d604a592182db9af26ea366/cachecontrol-0.14.2-py3-none-any.whl", hash = "sha256:ebad2091bf12d0d200dfc2464330db638c5deb41d546f6d7aca079e87290f3b0", size = 21780 }, +] + +[package.optional-dependencies] +filecache = [ + { name = "filelock" }, +] + +[[package]] +name = "cached-property" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/4b/3d870836119dbe9a5e3c9a61af8cc1a8b69d75aea564572e385882d5aefb/cached_property-2.0.1.tar.gz", hash = "sha256:484d617105e3ee0e4f1f58725e72a8ef9e93deee462222dbd51cd91230897641", size = 10574 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/0e/7d8225aab3bc1a0f5811f8e1b557aa034ac04bdf641925b30d3caf586b28/cached_property-2.0.1-py3-none-any.whl", hash = "sha256:f617d70ab1100b7bcf6e42228f9ddcb78c676ffa167278d9f730d1c2fba69ccb", size = 7428 }, +] + +[[package]] +name = "cattrs" +version = "24.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/65/af6d57da2cb32c076319b7489ae0958f746949d407109e3ccf4d115f147c/cattrs-24.1.2.tar.gz", hash = "sha256:8028cfe1ff5382df59dd36474a86e02d817b06eaf8af84555441bac915d2ef85", size = 426462 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/d5/867e75361fc45f6de75fe277dd085627a9db5ebb511a87f27dc1396b5351/cattrs-24.1.2-py3-none-any.whl", hash = "sha256:67c7495b760168d931a10233f979b28dc04daf853b30752246f4f8471c6d68d0", size = 66446 }, +] + +[[package]] +name = "certifi" +version = "2024.12.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/bd/1d41ee578ce09523c81a15426705dd20969f5abf006d1afe8aeff0dd776a/certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db", size = 166010 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, +] + +[[package]] +name = "cffi" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/07/f44ca684db4e4f08a3fdc6eeb9a0d15dc6883efc7b8c90357fdbf74e186c/cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", size = 182191 }, + { url = "https://files.pythonhosted.org/packages/08/fd/cc2fedbd887223f9f5d170c96e57cbf655df9831a6546c1727ae13fa977a/cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", size = 178592 }, + { url = "https://files.pythonhosted.org/packages/de/cc/4635c320081c78d6ffc2cab0a76025b691a91204f4aa317d568ff9280a2d/cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", size = 426024 }, + { url = "https://files.pythonhosted.org/packages/b6/7b/3b2b250f3aab91abe5f8a51ada1b717935fdaec53f790ad4100fe2ec64d1/cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", size = 448188 }, + { url = "https://files.pythonhosted.org/packages/d3/48/1b9283ebbf0ec065148d8de05d647a986c5f22586b18120020452fff8f5d/cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", size = 455571 }, + { url = "https://files.pythonhosted.org/packages/40/87/3b8452525437b40f39ca7ff70276679772ee7e8b394934ff60e63b7b090c/cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", size = 436687 }, + { url = "https://files.pythonhosted.org/packages/8d/fb/4da72871d177d63649ac449aec2e8a29efe0274035880c7af59101ca2232/cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", size = 446211 }, + { url = "https://files.pythonhosted.org/packages/ab/a0/62f00bcb411332106c02b663b26f3545a9ef136f80d5df746c05878f8c4b/cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", size = 461325 }, + { url = "https://files.pythonhosted.org/packages/36/83/76127035ed2e7e27b0787604d99da630ac3123bfb02d8e80c633f218a11d/cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", size = 438784 }, + { url = "https://files.pythonhosted.org/packages/21/81/a6cd025db2f08ac88b901b745c163d884641909641f9b826e8cb87645942/cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", size = 461564 }, + { url = "https://files.pythonhosted.org/packages/f8/fe/4d41c2f200c4a457933dbd98d3cf4e911870877bd94d9656cc0fcb390681/cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", size = 171804 }, + { url = "https://files.pythonhosted.org/packages/d1/b6/0b0f5ab93b0df4acc49cae758c81fe4e5ef26c3ae2e10cc69249dfd8b3ab/cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", size = 181299 }, + { url = "https://files.pythonhosted.org/packages/6b/f4/927e3a8899e52a27fa57a48607ff7dc91a9ebe97399b357b85a0c7892e00/cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", size = 182264 }, + { url = "https://files.pythonhosted.org/packages/6c/f5/6c3a8efe5f503175aaddcbea6ad0d2c96dad6f5abb205750d1b3df44ef29/cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", size = 178651 }, + { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259 }, + { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200 }, + { url = "https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235 }, + { url = "https://files.pythonhosted.org/packages/62/12/ce8710b5b8affbcdd5c6e367217c242524ad17a02fe5beec3ee339f69f85/cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", size = 459721 }, + { url = "https://files.pythonhosted.org/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", size = 467242 }, + { url = "https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999 }, + { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242 }, + { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604 }, + { url = "https://files.pythonhosted.org/packages/34/33/e1b8a1ba29025adbdcda5fb3a36f94c03d771c1b7b12f726ff7fef2ebe36/cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", size = 171727 }, + { url = "https://files.pythonhosted.org/packages/3d/97/50228be003bb2802627d28ec0627837ac0bf35c90cf769812056f235b2d1/cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", size = 181400 }, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/58/5580c1716040bc89206c77d8f74418caf82ce519aae06450393ca73475d1/charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de", size = 198013 }, + { url = "https://files.pythonhosted.org/packages/d0/11/00341177ae71c6f5159a08168bcb98c6e6d196d372c94511f9f6c9afe0c6/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176", size = 141285 }, + { url = "https://files.pythonhosted.org/packages/01/09/11d684ea5819e5a8f5100fb0b38cf8d02b514746607934134d31233e02c8/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037", size = 151449 }, + { url = "https://files.pythonhosted.org/packages/08/06/9f5a12939db324d905dc1f70591ae7d7898d030d7662f0d426e2286f68c9/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f", size = 143892 }, + { url = "https://files.pythonhosted.org/packages/93/62/5e89cdfe04584cb7f4d36003ffa2936681b03ecc0754f8e969c2becb7e24/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a", size = 146123 }, + { url = "https://files.pythonhosted.org/packages/a9/ac/ab729a15c516da2ab70a05f8722ecfccc3f04ed7a18e45c75bbbaa347d61/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a", size = 147943 }, + { url = "https://files.pythonhosted.org/packages/03/d2/3f392f23f042615689456e9a274640c1d2e5dd1d52de36ab8f7955f8f050/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247", size = 142063 }, + { url = "https://files.pythonhosted.org/packages/f2/e3/e20aae5e1039a2cd9b08d9205f52142329f887f8cf70da3650326670bddf/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408", size = 150578 }, + { url = "https://files.pythonhosted.org/packages/8d/af/779ad72a4da0aed925e1139d458adc486e61076d7ecdcc09e610ea8678db/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb", size = 153629 }, + { url = "https://files.pythonhosted.org/packages/c2/b6/7aa450b278e7aa92cf7732140bfd8be21f5f29d5bf334ae987c945276639/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d", size = 150778 }, + { url = "https://files.pythonhosted.org/packages/39/f4/d9f4f712d0951dcbfd42920d3db81b00dd23b6ab520419626f4023334056/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807", size = 146453 }, + { url = "https://files.pythonhosted.org/packages/49/2b/999d0314e4ee0cff3cb83e6bc9aeddd397eeed693edb4facb901eb8fbb69/charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f", size = 95479 }, + { url = "https://files.pythonhosted.org/packages/2d/ce/3cbed41cff67e455a386fb5e5dd8906cdda2ed92fbc6297921f2e4419309/charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f", size = 102790 }, + { url = "https://files.pythonhosted.org/packages/72/80/41ef5d5a7935d2d3a773e3eaebf0a9350542f2cab4eac59a7a4741fbbbbe/charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125", size = 194995 }, + { url = "https://files.pythonhosted.org/packages/7a/28/0b9fefa7b8b080ec492110af6d88aa3dea91c464b17d53474b6e9ba5d2c5/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1", size = 139471 }, + { url = "https://files.pythonhosted.org/packages/71/64/d24ab1a997efb06402e3fc07317e94da358e2585165930d9d59ad45fcae2/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3", size = 149831 }, + { url = "https://files.pythonhosted.org/packages/37/ed/be39e5258e198655240db5e19e0b11379163ad7070962d6b0c87ed2c4d39/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd", size = 142335 }, + { url = "https://files.pythonhosted.org/packages/88/83/489e9504711fa05d8dde1574996408026bdbdbd938f23be67deebb5eca92/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00", size = 143862 }, + { url = "https://files.pythonhosted.org/packages/c6/c7/32da20821cf387b759ad24627a9aca289d2822de929b8a41b6241767b461/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12", size = 145673 }, + { url = "https://files.pythonhosted.org/packages/68/85/f4288e96039abdd5aeb5c546fa20a37b50da71b5cf01e75e87f16cd43304/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77", size = 140211 }, + { url = "https://files.pythonhosted.org/packages/28/a3/a42e70d03cbdabc18997baf4f0227c73591a08041c149e710045c281f97b/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146", size = 148039 }, + { url = "https://files.pythonhosted.org/packages/85/e4/65699e8ab3014ecbe6f5c71d1a55d810fb716bbfd74f6283d5c2aa87febf/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd", size = 151939 }, + { url = "https://files.pythonhosted.org/packages/b1/82/8e9fe624cc5374193de6860aba3ea8070f584c8565ee77c168ec13274bd2/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6", size = 149075 }, + { url = "https://files.pythonhosted.org/packages/3d/7b/82865ba54c765560c8433f65e8acb9217cb839a9e32b42af4aa8e945870f/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8", size = 144340 }, + { url = "https://files.pythonhosted.org/packages/b5/b6/9674a4b7d4d99a0d2df9b215da766ee682718f88055751e1e5e753c82db0/charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b", size = 95205 }, + { url = "https://files.pythonhosted.org/packages/1e/ab/45b180e175de4402dcf7547e4fb617283bae54ce35c27930a6f35b6bef15/charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76", size = 102441 }, + { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, +] + +[[package]] +name = "clang-format" +version = "19.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ee/71d017fe603c06b83d6720df6b3f6f07f03abf330f39beee3fee2a067c56/clang_format-19.1.7.tar.gz", hash = "sha256:bd6fc5272a41034a7844149203461d1f311bece9ed100d22eb3eebd952a25f49", size = 11122 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/c3/2f1c53bc298c1740d0c9f8dc2d9b7030be4826b6f2aa8a04f07ef25a3d9b/clang_format-19.1.7-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:a09f34d2c89d176581858ff718c327eebc14eb6415c176dab4af5bfd8582a999", size = 1428184 }, + { url = "https://files.pythonhosted.org/packages/8e/9d/7c246a3d08105de305553d14971ed6c16cde06d20ab12d6ce7f243cf66f0/clang_format-19.1.7-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:776f89c7b056c498c0e256485bc031cbf514aaebe71e929ed54e50c478524b65", size = 1398224 }, + { url = "https://files.pythonhosted.org/packages/b1/7d/002aa5571351ee7f00f87aae5104cdd30cad1a46f25936226f7d2aed06bf/clang_format-19.1.7-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dac394c83a9233ab6707f66e1cdbd950f8b014b58604142a5b6f7998bf0bcc8c", size = 1730962 }, + { url = "https://files.pythonhosted.org/packages/1c/fe/24b7c13af432e609d65dc32c47c61f0a6c3b80d78eb7b3df37daf0395c56/clang_format-19.1.7-py2.py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbd4f94d929edf6d8d81e990dfaafc22bb10deaefcb2762150a136f281b01c00", size = 1908820 }, + { url = "https://files.pythonhosted.org/packages/7d/a8/86595ffd6ea0bf3a3013aad94e3d55be32ef987567781eddf4621e316d09/clang_format-19.1.7-py2.py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bdcda63fffdbe2aac23b54d46408a6283ad16676a5230a95b3ed49eacd99129b", size = 2622838 }, + { url = "https://files.pythonhosted.org/packages/48/d1/731ebf78c5d5cc043c20b0755c89239350b8e75ac5d667b99689e8110bc7/clang_format-19.1.7-py2.py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c13a5802da986b1400afbee97162c29f841890ab9e20a0be7ede18189219f5f1", size = 1723352 }, + { url = "https://files.pythonhosted.org/packages/3c/e7/0e526915a3a4a23100cc721c24226a192fa0385d394019d06920dc83fe6c/clang_format-19.1.7-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4906fb463dd2033032978f56962caab268c9428a384126b9400543eb667f11c", size = 1740347 }, + { url = "https://files.pythonhosted.org/packages/52/04/ed8e2af6b3e29655a858b3aad145f3f0539df0dd1c77815b95f578260bd3/clang_format-19.1.7-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ffca915c09aed9137f8c649ad7521bd5ce690c939121db1ba54af2ba63ac8374", size = 2675802 }, + { url = "https://files.pythonhosted.org/packages/9a/ab/7874a6f45c167f4cc4d02f517b85d14b6b5fa8412f6e9c7482588d00fccb/clang_format-19.1.7-py2.py3-none-musllinux_1_2_i686.whl", hash = "sha256:fc011dc7bbe3ac8a32e0caa37ab8ba6c1639ceef6ecd04feea8d37360fc175e4", size = 2977872 }, + { url = "https://files.pythonhosted.org/packages/46/b5/c87b6c46eb7e9d0f07e2bd56cd0a62bf7e679f146b4e1447110cfae4bd01/clang_format-19.1.7-py2.py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:afdfb11584f5a6f15127a7061673a7ea12a0393fe9ee8d2ed84e74bb191ffc3b", size = 3125795 }, + { url = "https://files.pythonhosted.org/packages/22/3e/7ea08aba446c1e838367d3c0e13eb3d2e482b23e099a25149d4f7f6b8c75/clang_format-19.1.7-py2.py3-none-musllinux_1_2_s390x.whl", hash = "sha256:6ce81d5b08e0169dc52037d3ff1802eafcaf86c281ceb8b38b8359ba7b6b7bdc", size = 3069663 }, + { url = "https://files.pythonhosted.org/packages/f5/f9/6ce7fe8ff52ded01d02a568358f2ddf993347e44202b6506b039a583b7ed/clang_format-19.1.7-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d27ac1a5a8783c9271d41cd5851766ca547ea003efa4e3764f880f319b2d3ed3", size = 2763172 }, + { url = "https://files.pythonhosted.org/packages/82/fa/77fe5636bb6b6252918bf129226a248506af218a2256deece3a9d95af850/clang_format-19.1.7-py2.py3-none-win32.whl", hash = "sha256:5dfde0be33f038114af89efb917144c2f766f8b7f3a3d3e4cb9c25f76d71ef81", size = 1243262 }, + { url = "https://files.pythonhosted.org/packages/e4/32/0b44f3582b9df0b8f90266ef43975e37ec8ad52bae4f85b71552f264d5a2/clang_format-19.1.7-py2.py3-none-win_amd64.whl", hash = "sha256:3e3c75fbdf8827bbb7277226b3057fc3785dabe7284d3a9d15fceb250f68f529", size = 1441132 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "cmake" +version = "3.31.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/50/cb/3a327fa784a5dbaf838b135cb1729f43535c52d83bbf02191fb8a0cb118e/cmake-3.31.4.tar.gz", hash = "sha256:a6ac2242e0b16ad7d94c9f8572d6f232e6169747be50e5cdf497f206c4819ce1", size = 34278 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/db/50efa1d3e29cb2a6e8e143e522e52698b3fc08f4b56100fb35f97a70af79/cmake-3.31.4-py3-none-macosx_10_10_universal2.whl", hash = "sha256:fc048b4b70facd16699a43c737f6782b4eff56e8e6093090db5979532d9db0f6", size = 47198138 }, + { url = "https://files.pythonhosted.org/packages/c7/76/ccb8764761c739ef16bd8957a16ecbda01b03c2d7d241c376bfca6bf2822/cmake-3.31.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2a37be93534df04513f0845492d71bc80899c3f87b77e3b01c95aff1a7fc9bde", size = 27556485 }, + { url = "https://files.pythonhosted.org/packages/ad/8e/888e2944655d7fa1ea5af46b60883a0e7847bbf9fb7ecc321c8e5f0a1394/cmake-3.31.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c9f5f8289c5e7bd2ed654cbac164021fa7723064fee0443a2f0068bc08413d81", size = 26808834 }, + { url = "https://files.pythonhosted.org/packages/59/f4/0b2b1430a441c3c09ee102bf8c5d9ec1dc11d002ff4affef15c656f37ce9/cmake-3.31.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:926d91cae2ba7d2f3df857d0fc066bdac4f3904bf5c95e99b60435e85aabedb4", size = 27140820 }, + { url = "https://files.pythonhosted.org/packages/d1/f9/a274b4e36e457d8e99db1038cc31a6c391bf3bc26230c2dc9caf37499753/cmake-3.31.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:929a8d8d289d69e43784661748ddd08933ce1ec5db8f9bcfce6ee817a48f8787", size = 28868269 }, + { url = "https://files.pythonhosted.org/packages/9b/35/8da1ffa00a3f3853881aa5025cdf11c744303013df70c8716155b83825d3/cmake-3.31.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b463efdf5b92f3b290235aa9f8da092b3dac19b7636c563fd156022dab580649", size = 30732267 }, + { url = "https://files.pythonhosted.org/packages/79/48/bb8485687f5a64d52ac68cfcb02e9b8e46a9e107f380c54d484b6632c87e/cmake-3.31.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:225d9a643b0b60ffce0399ff0cabd7a4820e0dbcb794e97d3aacfcf7c0589ae6", size = 26908885 }, + { url = "https://files.pythonhosted.org/packages/e5/9e/2594d7fa8b263296497bf044469b4ab4797c51675ea629f9672011cdfe09/cmake-3.31.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89143a5e2a5916061f2cfc5012e9fe6281aaf7c0dae7930bdc68d105d22ddc39", size = 27784555 }, + { url = "https://files.pythonhosted.org/packages/95/16/5b1989f1d2287b05cd68792c0a48b721c060f728506d719fcf0e3b80ceb2/cmake-3.31.4-py3-none-manylinux_2_31_armv7l.whl", hash = "sha256:f96127bf663168accd29d5a50ee68ea80f26bcd37f96c7a14ef2378781f19936", size = 24965366 }, + { url = "https://files.pythonhosted.org/packages/5a/4c/289fb0986c6ff63583383eca0c9479147f362330938856a9b5201c84cee8/cmake-3.31.4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:25c5094394f0cee21130b5678e5b4552f72470e266df6d6fb1d5c505100f0eaa", size = 27824887 }, + { url = "https://files.pythonhosted.org/packages/3c/f3/d45ba2b5bb54f4ef615a6a24cf6258600eec790a9d5017c9584107b445b9/cmake-3.31.4-py3-none-musllinux_1_1_i686.whl", hash = "sha256:466c9295af440bb4a47cc5e1af10576cf2227620528afd0fd0b3effa1d513b49", size = 31368421 }, + { url = "https://files.pythonhosted.org/packages/34/3d/f6b712241ede5fb8e32c13e119c06e142f3f12ead1656721b1f67756106b/cmake-3.31.4-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:f6af3b83a1b1fc1d990d18b6a566ee9c95c0393f986c6df15f2505dda8ad1bcc", size = 32074545 }, + { url = "https://files.pythonhosted.org/packages/f0/23/48cd0404d7238d703a4cd4d7434eeaf12e8fbe68160d52f1489f55f582df/cmake-3.31.4-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:23781e17563693a68b0cef85749746894b8a61488e56e96fc6649b73652e8236", size = 27946950 }, + { url = "https://files.pythonhosted.org/packages/21/03/014d9710bccf5a7e04c6f6ee27bfaba1220e79ee145d7b95f84e7843729b/cmake-3.31.4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:838a388b559137f3654d8cf30f62bbdec10f8d1c3624f0d289614d33cdf4fba1", size = 29473412 }, + { url = "https://files.pythonhosted.org/packages/23/de/5a8142732f0a52dedac2887e0c105c9bbb449e517ade500e56bf2af520d1/cmake-3.31.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6a3b0b9557f41c955a6b25c94205f2ca9c3a46edca809ad87507c5ef6bc4274", size = 32971081 }, + { url = "https://files.pythonhosted.org/packages/a5/a1/50c11f0b110986c753592f025970094030b25748df126abe8e38265be722/cmake-3.31.4-py3-none-win32.whl", hash = "sha256:d378c9e58eac906bddafd673c7571262dcd5a9946bb1e8f9e3902572a8fa95ca", size = 33351393 }, + { url = "https://files.pythonhosted.org/packages/0c/7f/331d181b6b1b8942ec5fad23e98fff85218485f29f62f6bc60663d424df8/cmake-3.31.4-py3-none-win_amd64.whl", hash = "sha256:20be7cdb41903edf85e8a498c4beff8d6854acbb087abfb07c362c738bdf0018", size = 36496715 }, + { url = "https://files.pythonhosted.org/packages/65/26/11a78723364716004928b7bea7d96cf2c72dc3abfaa7c163159110fcb649/cmake-3.31.4-py3-none-win_arm64.whl", hash = "sha256:9479a9255197c49e135df039d8484c69aa63158a06ae9c2d0eb939da2f0f7dff", size = 35559239 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "colorlog" +version = "6.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/7a/359f4d5df2353f26172b3cc39ea32daa39af8de522205f512f458923e677/colorlog-6.9.0.tar.gz", hash = "sha256:bfba54a1b93b94f54e1f4fe48395725a3d92fd2a4af702f6bd70946bdc0c6ac2", size = 16624 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/51/9b208e85196941db2f0654ad0357ca6388ab3ed67efdbfc799f35d1f83aa/colorlog-6.9.0-py3-none-any.whl", hash = "sha256:5906e71acd67cb07a71e779c47c4bcb45fb8c2993eebe9e5adcd6a6f1b283eff", size = 11424 }, +] + +[[package]] +name = "comm" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/a8/fb783cb0abe2b5fded9f55e5703015cdf1c9c85b3669087c538dd15a6a86/comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e", size = 6210 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180 }, +] + +[[package]] +name = "contourpy" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/c2/fc7193cc5383637ff390a712e88e4ded0452c9fbcf84abe3de5ea3df1866/contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699", size = 13465753 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/a3/80937fe3efe0edacf67c9a20b955139a1a622730042c1ea991956f2704ad/contourpy-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a045f341a77b77e1c5de31e74e966537bba9f3c4099b35bf4c2e3939dd54cdab", size = 268466 }, + { url = "https://files.pythonhosted.org/packages/82/1d/e3eaebb4aa2d7311528c048350ca8e99cdacfafd99da87bc0a5f8d81f2c2/contourpy-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:500360b77259914f7805af7462e41f9cb7ca92ad38e9f94d6c8641b089338124", size = 253314 }, + { url = "https://files.pythonhosted.org/packages/de/f3/d796b22d1a2b587acc8100ba8c07fb7b5e17fde265a7bb05ab967f4c935a/contourpy-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2f926efda994cdf3c8d3fdb40b9962f86edbc4457e739277b961eced3d0b4c1", size = 312003 }, + { url = "https://files.pythonhosted.org/packages/bf/f5/0e67902bc4394daee8daa39c81d4f00b50e063ee1a46cb3938cc65585d36/contourpy-1.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:adce39d67c0edf383647a3a007de0a45fd1b08dedaa5318404f1a73059c2512b", size = 351896 }, + { url = "https://files.pythonhosted.org/packages/1f/d6/e766395723f6256d45d6e67c13bb638dd1fa9dc10ef912dc7dd3dcfc19de/contourpy-1.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abbb49fb7dac584e5abc6636b7b2a7227111c4f771005853e7d25176daaf8453", size = 320814 }, + { url = "https://files.pythonhosted.org/packages/a9/57/86c500d63b3e26e5b73a28b8291a67c5608d4aa87ebd17bd15bb33c178bc/contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0cffcbede75c059f535725c1680dfb17b6ba8753f0c74b14e6a9c68c29d7ea3", size = 324969 }, + { url = "https://files.pythonhosted.org/packages/b8/62/bb146d1289d6b3450bccc4642e7f4413b92ebffd9bf2e91b0404323704a7/contourpy-1.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ab29962927945d89d9b293eabd0d59aea28d887d4f3be6c22deaefbb938a7277", size = 1265162 }, + { url = "https://files.pythonhosted.org/packages/18/04/9f7d132ce49a212c8e767042cc80ae390f728060d2eea47058f55b9eff1c/contourpy-1.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:974d8145f8ca354498005b5b981165b74a195abfae9a8129df3e56771961d595", size = 1324328 }, + { url = "https://files.pythonhosted.org/packages/46/23/196813901be3f97c83ababdab1382e13e0edc0bb4e7b49a7bff15fcf754e/contourpy-1.3.1-cp310-cp310-win32.whl", hash = "sha256:ac4578ac281983f63b400f7fe6c101bedc10651650eef012be1ccffcbacf3697", size = 173861 }, + { url = "https://files.pythonhosted.org/packages/e0/82/c372be3fc000a3b2005061ca623a0d1ecd2eaafb10d9e883a2fc8566e951/contourpy-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:174e758c66bbc1c8576992cec9599ce8b6672b741b5d336b5c74e35ac382b18e", size = 218566 }, + { url = "https://files.pythonhosted.org/packages/12/bb/11250d2906ee2e8b466b5f93e6b19d525f3e0254ac8b445b56e618527718/contourpy-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8b974d8db2c5610fb4e76307e265de0edb655ae8169e8b21f41807ccbeec4b", size = 269555 }, + { url = "https://files.pythonhosted.org/packages/67/71/1e6e95aee21a500415f5d2dbf037bf4567529b6a4e986594d7026ec5ae90/contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20914c8c973f41456337652a6eeca26d2148aa96dd7ac323b74516988bea89fc", size = 254549 }, + { url = "https://files.pythonhosted.org/packages/31/2c/b88986e8d79ac45efe9d8801ae341525f38e087449b6c2f2e6050468a42c/contourpy-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d40d37c1c3a4961b4619dd9d77b12124a453cc3d02bb31a07d58ef684d3d86", size = 313000 }, + { url = "https://files.pythonhosted.org/packages/c4/18/65280989b151fcf33a8352f992eff71e61b968bef7432fbfde3a364f0730/contourpy-1.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:113231fe3825ebf6f15eaa8bc1f5b0ddc19d42b733345eae0934cb291beb88b6", size = 352925 }, + { url = "https://files.pythonhosted.org/packages/f5/c7/5fd0146c93220dbfe1a2e0f98969293b86ca9bc041d6c90c0e065f4619ad/contourpy-1.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dbbc03a40f916a8420e420d63e96a1258d3d1b58cbdfd8d1f07b49fcbd38e85", size = 323693 }, + { url = "https://files.pythonhosted.org/packages/85/fc/7fa5d17daf77306840a4e84668a48ddff09e6bc09ba4e37e85ffc8e4faa3/contourpy-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a04ecd68acbd77fa2d39723ceca4c3197cb2969633836ced1bea14e219d077c", size = 326184 }, + { url = "https://files.pythonhosted.org/packages/ef/e7/104065c8270c7397c9571620d3ab880558957216f2b5ebb7e040f85eeb22/contourpy-1.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c414fc1ed8ee1dbd5da626cf3710c6013d3d27456651d156711fa24f24bd1291", size = 1268031 }, + { url = "https://files.pythonhosted.org/packages/e2/4a/c788d0bdbf32c8113c2354493ed291f924d4793c4a2e85b69e737a21a658/contourpy-1.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:31c1b55c1f34f80557d3830d3dd93ba722ce7e33a0b472cba0ec3b6535684d8f", size = 1325995 }, + { url = "https://files.pythonhosted.org/packages/a6/e6/a2f351a90d955f8b0564caf1ebe4b1451a3f01f83e5e3a414055a5b8bccb/contourpy-1.3.1-cp311-cp311-win32.whl", hash = "sha256:f611e628ef06670df83fce17805c344710ca5cde01edfdc72751311da8585375", size = 174396 }, + { url = "https://files.pythonhosted.org/packages/a8/7e/cd93cab453720a5d6cb75588cc17dcdc08fc3484b9de98b885924ff61900/contourpy-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:b2bdca22a27e35f16794cf585832e542123296b4687f9fd96822db6bae17bfc9", size = 219787 }, + { url = "https://files.pythonhosted.org/packages/3e/4f/e56862e64b52b55b5ddcff4090085521fc228ceb09a88390a2b103dccd1b/contourpy-1.3.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b457d6430833cee8e4b8e9b6f07aa1c161e5e0d52e118dc102c8f9bd7dd060d6", size = 265605 }, + { url = "https://files.pythonhosted.org/packages/b0/2e/52bfeeaa4541889f23d8eadc6386b442ee2470bd3cff9baa67deb2dd5c57/contourpy-1.3.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb76c1a154b83991a3cbbf0dfeb26ec2833ad56f95540b442c73950af2013750", size = 315040 }, + { url = "https://files.pythonhosted.org/packages/52/94/86bfae441707205634d80392e873295652fc313dfd93c233c52c4dc07874/contourpy-1.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53", size = 218221 }, +] + +[[package]] +name = "coverage" +version = "7.6.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/84/ba/ac14d281f80aab516275012e8875991bb06203957aa1e19950139238d658/coverage-7.6.10.tar.gz", hash = "sha256:7fb105327c8f8f0682e29843e2ff96af9dcbe5bab8eeb4b398c6a33a16d80a23", size = 803868 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/12/2a2a923edf4ddabdffed7ad6da50d96a5c126dae7b80a33df7310e329a1e/coverage-7.6.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5c912978f7fbf47ef99cec50c4401340436d200d41d714c7a4766f377c5b7b78", size = 207982 }, + { url = "https://files.pythonhosted.org/packages/ca/49/6985dbca9c7be3f3cb62a2e6e492a0c88b65bf40579e16c71ae9c33c6b23/coverage-7.6.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a01ec4af7dfeb96ff0078ad9a48810bb0cc8abcb0115180c6013a6b26237626c", size = 208414 }, + { url = "https://files.pythonhosted.org/packages/35/93/287e8f1d1ed2646f4e0b2605d14616c9a8a2697d0d1b453815eb5c6cebdb/coverage-7.6.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3b204c11e2b2d883946fe1d97f89403aa1811df28ce0447439178cc7463448a", size = 236860 }, + { url = "https://files.pythonhosted.org/packages/de/e1/cfdb5627a03567a10031acc629b75d45a4ca1616e54f7133ca1fa366050a/coverage-7.6.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32ee6d8491fcfc82652a37109f69dee9a830e9379166cb73c16d8dc5c2915165", size = 234758 }, + { url = "https://files.pythonhosted.org/packages/6d/85/fc0de2bcda3f97c2ee9fe8568f7d48f7279e91068958e5b2cc19e0e5f600/coverage-7.6.10-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675cefc4c06e3b4c876b85bfb7c59c5e2218167bbd4da5075cbe3b5790a28988", size = 235920 }, + { url = "https://files.pythonhosted.org/packages/79/73/ef4ea0105531506a6f4cf4ba571a214b14a884630b567ed65b3d9c1975e1/coverage-7.6.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f4f620668dbc6f5e909a0946a877310fb3d57aea8198bde792aae369ee1c23b5", size = 234986 }, + { url = "https://files.pythonhosted.org/packages/c6/4d/75afcfe4432e2ad0405c6f27adeb109ff8976c5e636af8604f94f29fa3fc/coverage-7.6.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4eea95ef275de7abaef630c9b2c002ffbc01918b726a39f5a4353916ec72d2f3", size = 233446 }, + { url = "https://files.pythonhosted.org/packages/86/5b/efee56a89c16171288cafff022e8af44f8f94075c2d8da563c3935212871/coverage-7.6.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e2f0280519e42b0a17550072861e0bc8a80a0870de260f9796157d3fca2733c5", size = 234566 }, + { url = "https://files.pythonhosted.org/packages/f2/db/67770cceb4a64d3198bf2aa49946f411b85ec6b0a9b489e61c8467a4253b/coverage-7.6.10-cp310-cp310-win32.whl", hash = "sha256:bc67deb76bc3717f22e765ab3e07ee9c7a5e26b9019ca19a3b063d9f4b874244", size = 210675 }, + { url = "https://files.pythonhosted.org/packages/8d/27/e8bfc43f5345ec2c27bc8a1fa77cdc5ce9dcf954445e11f14bb70b889d14/coverage-7.6.10-cp310-cp310-win_amd64.whl", hash = "sha256:0f460286cb94036455e703c66988851d970fdfd8acc2a1122ab7f4f904e4029e", size = 211518 }, + { url = "https://files.pythonhosted.org/packages/85/d2/5e175fcf6766cf7501a8541d81778fd2f52f4870100e791f5327fd23270b/coverage-7.6.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ea3c8f04b3e4af80e17bab607c386a830ffc2fb88a5484e1df756478cf70d1d3", size = 208088 }, + { url = "https://files.pythonhosted.org/packages/4b/6f/06db4dc8fca33c13b673986e20e466fd936235a6ec1f0045c3853ac1b593/coverage-7.6.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:507a20fc863cae1d5720797761b42d2d87a04b3e5aeb682ef3b7332e90598f43", size = 208536 }, + { url = "https://files.pythonhosted.org/packages/0d/62/c6a0cf80318c1c1af376d52df444da3608eafc913b82c84a4600d8349472/coverage-7.6.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d37a84878285b903c0fe21ac8794c6dab58150e9359f1aaebbeddd6412d53132", size = 240474 }, + { url = "https://files.pythonhosted.org/packages/a3/59/750adafc2e57786d2e8739a46b680d4fb0fbc2d57fbcb161290a9f1ecf23/coverage-7.6.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a534738b47b0de1995f85f582d983d94031dffb48ab86c95bdf88dc62212142f", size = 237880 }, + { url = "https://files.pythonhosted.org/packages/2c/f8/ef009b3b98e9f7033c19deb40d629354aab1d8b2d7f9cfec284dbedf5096/coverage-7.6.10-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d7a2bf79378d8fb8afaa994f91bfd8215134f8631d27eba3e0e2c13546ce994", size = 239750 }, + { url = "https://files.pythonhosted.org/packages/a6/e2/6622f3b70f5f5b59f705e680dae6db64421af05a5d1e389afd24dae62e5b/coverage-7.6.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6713ba4b4ebc330f3def51df1d5d38fad60b66720948112f114968feb52d3f99", size = 238642 }, + { url = "https://files.pythonhosted.org/packages/2d/10/57ac3f191a3c95c67844099514ff44e6e19b2915cd1c22269fb27f9b17b6/coverage-7.6.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ab32947f481f7e8c763fa2c92fd9f44eeb143e7610c4ca9ecd6a36adab4081bd", size = 237266 }, + { url = "https://files.pythonhosted.org/packages/ee/2d/7016f4ad9d553cabcb7333ed78ff9d27248ec4eba8dd21fa488254dff894/coverage-7.6.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7bbd8c8f1b115b892e34ba66a097b915d3871db7ce0e6b9901f462ff3a975377", size = 238045 }, + { url = "https://files.pythonhosted.org/packages/a7/fe/45af5c82389a71e0cae4546413266d2195c3744849669b0bab4b5f2c75da/coverage-7.6.10-cp311-cp311-win32.whl", hash = "sha256:299e91b274c5c9cdb64cbdf1b3e4a8fe538a7a86acdd08fae52301b28ba297f8", size = 210647 }, + { url = "https://files.pythonhosted.org/packages/db/11/3f8e803a43b79bc534c6a506674da9d614e990e37118b4506faf70d46ed6/coverage-7.6.10-cp311-cp311-win_amd64.whl", hash = "sha256:489a01f94aa581dbd961f306e37d75d4ba16104bbfa2b0edb21d29b73be83609", size = 211508 }, + { url = "https://files.pythonhosted.org/packages/a1/70/de81bfec9ed38a64fc44a77c7665e20ca507fc3265597c28b0d989e4082e/coverage-7.6.10-pp39.pp310-none-any.whl", hash = "sha256:fd34e7b3405f0cc7ab03d54a334c17a9e802897580d964bd8c2001f4b9fd488f", size = 200223 }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] + +[[package]] +name = "cssutils" +version = "2.11.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "more-itertools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/9f/329d26121fe165be44b1dfff21aa0dc348f04633931f1d20ed6cf448a236/cssutils-2.11.1.tar.gz", hash = "sha256:0563a76513b6af6eebbe788c3bf3d01c920e46b3f90c8416738c5cfc773ff8e2", size = 711657 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/ec/bb273b7208c606890dc36540fe667d06ce840a6f62f9fae7e658fcdc90fb/cssutils-2.11.1-py3-none-any.whl", hash = "sha256:a67bfdfdff4f3867fab43698ec4897c1a828eca5973f4073321b3bccaf1199b1", size = 385747 }, +] + +[[package]] +name = "cupy-cuda11x" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/1b/3afbaea2b78114c82b33ecc9affc79b7d9f4899945940b9b50790c93fd33/cupy_cuda11x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ef854f0c63525d8163ab7af19f503d964de9dde0dd1cf9ea806a6ecb302cdce3", size = 109578634 }, + { url = "https://files.pythonhosted.org/packages/82/94/1da4205249baa861ac848dcbc36208a0b08f2ba2c414634525e53dabf818/cupy_cuda11x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:54bf12a6663d0471e3e37e62972add348c5263ce803688f48bbfab1b20ebdb02", size = 96619611 }, + { url = "https://files.pythonhosted.org/packages/3f/ef/6924de40b67d4a0176e9c27f1ea9b0c8700935424473afd104cf72b36eb0/cupy_cuda11x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:972d133efa2af80bb8ef321858ffe7cabc3abf8f58bcc4f13541dd497c05077d", size = 76006133 }, + { url = "https://files.pythonhosted.org/packages/4d/2d/9f01f25a81535572050f77ca618a54d8ad08afc13963c9fc57c162931e42/cupy_cuda11x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:766ef1558a3ed967d5f092829bfb99edbcfaf75224925e1fb1a9f531e1e79f36", size = 110899612 }, + { url = "https://files.pythonhosted.org/packages/96/8f/b92bbf066ed86ec9dbeb969a5d6e6b6597bf0bab730f9e8b4c589f7cf198/cupy_cuda11x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:77a81fa48d1a392b731885555a53cf2febde39cc33db55f2d78ba64b5ef4689b", size = 97172154 }, + { url = "https://files.pythonhosted.org/packages/08/94/113cc947b06b45b950979441a4f12f257b203d9a33796b1dbe6b82a2c36c/cupy_cuda11x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:a8e8b7f7f73677afe2f70c38562f01f82688e43147550b3e192a5a2206e17fe1", size = 75976673 }, +] + +[[package]] +name = "cupy-cuda12x" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/60/dc268d1d9c5fdde4673a463feff5e9c70c59f477e647b54b501f65deef60/cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:674488e990998042cc54d2486d3c37cae80a12ba3787636be5a10b9446dd6914", size = 103601326 }, + { url = "https://files.pythonhosted.org/packages/7a/a9/1e19ecf008011df2935d038f26f721f22f2804c00077fc024f088e0996e6/cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:cf4a2a0864364715881b50012927e88bd7ec1e6f1de3987970870861ae5ed25e", size = 90619949 }, + { url = "https://files.pythonhosted.org/packages/ce/6b/e77e3fc20648d323021f55d4e0fafc5572eff50c37750d6aeae868e110d8/cupy_cuda12x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7c0dc8c49d271d1c03e49a5d6c8e42e8fee3114b10f269a5ecc387731d693eaa", size = 69594183 }, + { url = "https://files.pythonhosted.org/packages/95/c9/0b88c015e98aad808c18f938267585d79e6211fe08650e0de7132e235e40/cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c0cc095b9a3835fd5db66c45ed3c58ecdc5a3bb14e53e1defbfd4a0ce5c8ecdb", size = 104925909 }, + { url = "https://files.pythonhosted.org/packages/8c/1f/596803c35833c01a41da21c6a7bb552f1ed56d807090ddc6727c8f396d7d/cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a0e3bead04e502ebde515f0343444ca3f4f7aed09cbc3a316a946cba97f2ea66", size = 91172049 }, + { url = "https://files.pythonhosted.org/packages/d0/a8/5b5929830d2da94608d8126bafe2c52d69929a197fd8698ac09142c068ba/cupy_cuda12x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:5f11df1149c7219858b27e4c8be92cb4eaf7364c94af6b78c40dffb98050a61f", size = 69564719 }, +] + +[[package]] +name = "cupy-rocm-4-3" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/16/7fd4bc8a8f1a4697f76e52c13f348f284fcc5c37195efd7e4c5d0eb2b15c/cupy_rocm_4_3-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:fc6b93be093bcea8b820baed856b61efc5c8cb09b02ebdc890431655714366ad", size = 41259087 }, + { url = "https://files.pythonhosted.org/packages/2e/ee/e893b0fdc6b347d8d65024442e5baf5ae13ee92c1364152e8f343906793d/cupy_rocm_4_3-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:f5e6886f1750810ddc3d261adf84d98b4d42f1d3cb2be5b7f5da181c8bf1593d", size = 41775360 }, +] + +[[package]] +name = "cupy-rocm-5-0" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/2e/6e4ecd65f5158808a54ef75d90fc7a884afb55bd405c4a7dbc34bb4a8f96/cupy_rocm_5_0-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d4c370441f7778b00f3ab80d6f0d669ea0215b6e96bbed9663ecce7ffce83fa9", size = 60056031 }, + { url = "https://files.pythonhosted.org/packages/08/52/8b5b6b32c84616989a2a84f02d9f4ca39d812de9f630276a664f321840bf/cupy_rocm_5_0-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:00907762735d182737bee317f532dc381337fb8e978bd846acb268df463b2d7b", size = 60576552 }, +] + +[[package]] +name = "cycler" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 }, +] + +[[package]] +name = "cython" +version = "3.0.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/84/4d/b720d6000f4ca77f030bd70f12550820f0766b568e43f11af7f7ad9061aa/cython-3.0.11.tar.gz", hash = "sha256:7146dd2af8682b4ca61331851e6aebce9fe5158e75300343f80c07ca80b1faff", size = 2755544 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/7f/ab5796a0951328d7818b771c36fe7e1a2077cffa28c917d9fa4a642728c3/Cython-3.0.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:44292aae17524abb4b70a25111fe7dec1a0ad718711d47e3786a211d5408fdaa", size = 3100879 }, + { url = "https://files.pythonhosted.org/packages/d8/3b/67480e609537e9fc899864847910ded481b82d033fea1b7fcf85893a2fc4/Cython-3.0.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75d45fbc20651c1b72e4111149fed3b33d270b0a4fb78328c54d965f28d55e1", size = 3461957 }, + { url = "https://files.pythonhosted.org/packages/f0/89/b1ae45689abecca777f95462781a76e67ff46b55495a481ec5a73a739994/Cython-3.0.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89a82937ce4037f092e9848a7bbcc65bc8e9fc9aef2bb74f5c15e7d21a73080", size = 3627062 }, + { url = "https://files.pythonhosted.org/packages/44/77/a651da74d5d41c6045bbe0b6990b1515bf4850cd7a8d8580333c90dfce2e/Cython-3.0.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ea2e7e2d3bc0d8630dafe6c4a5a89485598ff8a61885b74f8ed882597efd5", size = 3680431 }, + { url = "https://files.pythonhosted.org/packages/59/45/60e7e8db93c3eb8b2af8c64020c1fa502e355f4b762886a24d46e433f395/Cython-3.0.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cee29846471ce60226b18e931d8c1c66a158db94853e3e79bc2da9bd22345008", size = 3497314 }, + { url = "https://files.pythonhosted.org/packages/f8/0b/6919025958926625319f83523ee7f45e7e7ae516b8054dcff6eb710daf32/Cython-3.0.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eeb6860b0f4bfa402de8929833fe5370fa34069c7ebacb2d543cb017f21fb891", size = 3709091 }, + { url = "https://files.pythonhosted.org/packages/52/3c/c21b9b9271dfaa46fa2938de730f62fc94b9c2ec25ec400585e372f35dcd/Cython-3.0.11-cp310-cp310-win32.whl", hash = "sha256:3699391125ab344d8d25438074d1097d9ba0fb674d0320599316cfe7cf5f002a", size = 2576110 }, + { url = "https://files.pythonhosted.org/packages/f9/de/19fdd1c7a52e0534bf5f544e0346c15d71d20338dbd013117f763b94613f/Cython-3.0.11-cp310-cp310-win_amd64.whl", hash = "sha256:d02f4ebe15aac7cdacce1a628e556c1983f26d140fd2e0ac5e0a090e605a2d38", size = 2776386 }, + { url = "https://files.pythonhosted.org/packages/f8/73/e55be864199cd674cb3426a052726c205589b1ac66fb0090e7fe793b60b3/Cython-3.0.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75ba1c70b6deeaffbac123856b8d35f253da13552207aa969078611c197377e4", size = 3113599 }, + { url = "https://files.pythonhosted.org/packages/09/c9/537108d0980beffff55336baaf8b34162ad0f3f33ededcb5db07069bc8ef/Cython-3.0.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af91497dc098718e634d6ec8f91b182aea6bb3690f333fc9a7777bc70abe8810", size = 3441131 }, + { url = "https://files.pythonhosted.org/packages/93/03/e330b241ad8aa12bb9d98b58fb76d4eb7dcbe747479aab5c29fce937b9e7/Cython-3.0.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3999fb52d3328a6a5e8c63122b0a8bd110dfcdb98dda585a3def1426b991cba7", size = 3595065 }, + { url = "https://files.pythonhosted.org/packages/4a/84/a3c40f2c0439d425daa5aa4e3a6fdbbb41341a14a6fd97f94906f528d9a4/Cython-3.0.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d566a4e09b8979be8ab9f843bac0dd216c81f5e5f45661a9b25cd162ed80508c", size = 3641667 }, + { url = "https://files.pythonhosted.org/packages/6d/93/bdb61e0254ed8f1d21a14088a473584ecb1963d68dba5682158aa45c70ef/Cython-3.0.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:46aec30f217bdf096175a1a639203d44ac73a36fe7fa3dd06bd012e8f39eca0f", size = 3503650 }, + { url = "https://files.pythonhosted.org/packages/f8/62/0da548144c71176155ff5355c4cc40fb28b9effe22e830b55cec8072bdf2/Cython-3.0.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ddd1fe25af330f4e003421636746a546474e4ccd8f239f55d2898d80983d20ed", size = 3709662 }, + { url = "https://files.pythonhosted.org/packages/56/d3/d9c9eaf3611a9fe5256266d07b6a5f9069aa84d20d9f6aa5824289513315/Cython-3.0.11-cp311-cp311-win32.whl", hash = "sha256:221de0b48bf387f209003508e602ce839a80463522fc6f583ad3c8d5c890d2c1", size = 2577870 }, + { url = "https://files.pythonhosted.org/packages/fd/10/236fcc0306f85a2db1b8bc147aea714b66a2f27bac4d9e09e5b2c5d5dcca/Cython-3.0.11-cp311-cp311-win_amd64.whl", hash = "sha256:3ff8ac1f0ecd4f505db4ab051e58e4531f5d098b6ac03b91c3b902e8d10c67b3", size = 2785053 }, + { url = "https://files.pythonhosted.org/packages/43/39/bdbec9142bc46605b54d674bf158a78b191c2b75be527c6dcf3e6dfe90b8/Cython-3.0.11-py2.py3-none-any.whl", hash = "sha256:0e25f6425ad4a700d7f77cd468da9161e63658837d1bc34861a9861a4ef6346d", size = 1171267 }, +] + +[[package]] +name = "cytoolz" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "toolz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/f9/3243eed3a6545c2a33a21f74f655e3fcb5d2192613cd3db81a93369eb339/cytoolz-1.0.1.tar.gz", hash = "sha256:89cc3161b89e1bb3ed7636f74ed2e55984fd35516904fc878cae216e42b2c7d6", size = 626652 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d9/f13d66c16cff1fa1cb6c234698029877c456f35f577ef274aba3b86e7c51/cytoolz-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cec9af61f71fc3853eb5dca3d42eb07d1f48a4599fa502cbe92adde85f74b042", size = 403515 }, + { url = "https://files.pythonhosted.org/packages/4b/2d/4cdf848a69300c7d44984f2ebbebb3b8576e5449c8dea157298f3bdc4da3/cytoolz-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:140bbd649dbda01e91add7642149a5987a7c3ccc251f2263de894b89f50b6608", size = 383936 }, + { url = "https://files.pythonhosted.org/packages/72/a4/ccfdd3f0ed9cc818f734b424261f6018fc61e3ec833bf85225a9aca0d994/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e90124bdc42ff58b88cdea1d24a6bc5f776414a314cc4d94f25c88badb3a16d1", size = 1934569 }, + { url = "https://files.pythonhosted.org/packages/50/fc/38d5344fa595683ad10dc819cfc1d8b9d2b3391ccf3e8cb7bab4899a01f5/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e74801b751e28f7c5cc3ad264c123954a051f546f2fdfe089f5aa7a12ccfa6da", size = 2015129 }, + { url = "https://files.pythonhosted.org/packages/28/29/75261748dc54a20a927f33641f4e9aac674cfc6d3fbd4f332e10d0b37639/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:582dad4545ddfb5127494ef23f3fa4855f1673a35d50c66f7638e9fb49805089", size = 2000506 }, + { url = "https://files.pythonhosted.org/packages/00/ae/e4ead004cc2698281d153c4a5388638d67cdb5544d6d6cc1e5b3db2bd2a3/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7bd0618e16efe03bd12f19c2a26a27e6e6b75d7105adb7be1cd2a53fa755d8", size = 1957537 }, + { url = "https://files.pythonhosted.org/packages/4a/ff/4f3aa07f4f47701f7f63df60ce0a5669fa09c256c3d4a33503a9414ea5cc/cytoolz-1.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d74cca6acf1c4af58b2e4a89cc565ed61c5e201de2e434748c93e5a0f5c541a5", size = 1863331 }, + { url = "https://files.pythonhosted.org/packages/a2/29/654f57f2a9b8e9765a4ab876765f64f94530b61fc6471a07feea42ece6d4/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:823a3763828d8d457f542b2a45d75d6b4ced5e470b5c7cf2ed66a02f508ed442", size = 1849938 }, + { url = "https://files.pythonhosted.org/packages/bc/7b/11f457db6b291060a98315ab2c7198077d8bddeeebe5f7126d9dad98cc54/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:51633a14e6844c61db1d68c1ffd077cf949f5c99c60ed5f1e265b9e2966f1b52", size = 1852345 }, + { url = "https://files.pythonhosted.org/packages/6b/92/0dccc96ce0323be236d404f5084479b79b747fa0e74e43a270e95868b5f9/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f3ec9b01c45348f1d0d712507d54c2bfd69c62fbd7c9ef555c9d8298693c2432", size = 1989877 }, + { url = "https://files.pythonhosted.org/packages/a3/c8/1c5203a81200bae51aa8f7b5fad613f695bf1afa03f16251ca23ecb2ef9f/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1855022b712a9c7a5bce354517ab4727a38095f81e2d23d3eabaf1daeb6a3b3c", size = 1994492 }, + { url = "https://files.pythonhosted.org/packages/e2/8a/04bc193c4d7ced8ef6bb62cdcd0bf40b5e5eb26586ed2cfb4433ec7dfd0a/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9930f7288c4866a1dc1cc87174f0c6ff4cad1671eb1f6306808aa6c445857d78", size = 1896077 }, + { url = "https://files.pythonhosted.org/packages/21/a5/bee63a58f51d2c74856db66e6119a014464ff8cb1c9387fa4bd2d94e49b0/cytoolz-1.0.1-cp310-cp310-win32.whl", hash = "sha256:a9baad795d72fadc3445ccd0f122abfdbdf94269157e6d6d4835636dad318804", size = 322135 }, + { url = "https://files.pythonhosted.org/packages/e8/16/7abfb1685e8b7f2838264551ee33651748994813f566ac4c3d737dfe90e5/cytoolz-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:ad95b386a84e18e1f6136f6d343d2509d4c3aae9f5a536f3dc96808fcc56a8cf", size = 363599 }, + { url = "https://files.pythonhosted.org/packages/dc/ea/8131ae39119820b8867cddc23716fa9f681f2b3bbce6f693e68dfb36b55b/cytoolz-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2d958d4f04d9d7018e5c1850790d9d8e68b31c9a2deebca74b903706fdddd2b6", size = 406162 }, + { url = "https://files.pythonhosted.org/packages/26/18/3d9bd4c146f6ea6e51300c242b20cb416966b21d481dac230e1304f1e54b/cytoolz-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f445b8b731fc0ecb1865b8e68a070084eb95d735d04f5b6c851db2daf3048ab", size = 384961 }, + { url = "https://files.pythonhosted.org/packages/e4/73/9034827907c7f85c7c484c9494e905d022fb8174526004e9ef332570349e/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f546a96460a7e28eb2ec439f4664fa646c9b3e51c6ebad9a59d3922bbe65e30", size = 2091698 }, + { url = "https://files.pythonhosted.org/packages/74/af/d5c2733b0fde1a08254ff1a8a8d567874040c9eb1606363cfebc0713c73f/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0317681dd065532d21836f860b0563b199ee716f55d0c1f10de3ce7100c78a3b", size = 2188452 }, + { url = "https://files.pythonhosted.org/packages/6a/bb/77c71fa9c217260b4056a732d754748903423c2cdd82a673d6064741e375/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c0ef52febd5a7821a3fd8d10f21d460d1a3d2992f724ba9c91fbd7a96745d41", size = 2174203 }, + { url = "https://files.pythonhosted.org/packages/fc/a9/a5b4a3ff5d22faa1b60293bfe97362e2caf4a830c26d37ab5557f60d04b2/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ebaf419acf2de73b643cf96108702b8aef8e825cf4f63209ceb078d5fbbbfd", size = 2099831 }, + { url = "https://files.pythonhosted.org/packages/35/08/7f6869ea1ff31ce5289a7d58d0e7090acfe7058baa2764473048ff61ea3c/cytoolz-1.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f7f04eeb4088947585c92d6185a618b25ad4a0f8f66ea30c8db83cf94a425e3", size = 1996744 }, + { url = "https://files.pythonhosted.org/packages/46/b4/9ac424c994b51763fd1bbed62d95f8fba8fa0e45c8c3c583904fdaf8f51d/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f61928803bb501c17914b82d457c6f50fe838b173fb40d39c38d5961185bd6c7", size = 2013733 }, + { url = "https://files.pythonhosted.org/packages/3e/99/03009765c4b87d742d5b5a8670abb56a8c7ede033c2cdaa4be8662d3b001/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d2960cb4fa01ccb985ad1280db41f90dc97a80b397af970a15d5a5de403c8c61", size = 1994850 }, + { url = "https://files.pythonhosted.org/packages/40/9a/8458af9a5557e177ea42f8cf7e477bede518b0bbef564e28c4151feaa52c/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b2b407cc3e9defa8df5eb46644f6f136586f70ba49eba96f43de67b9a0984fd3", size = 2155352 }, + { url = "https://files.pythonhosted.org/packages/5e/5c/2a701423e001fcbec288b4f3fc2bf67557d114c2388237fc1ae67e1e2686/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8245f929144d4d3bd7b972c9593300195c6cea246b81b4c46053c48b3f044580", size = 2163515 }, + { url = "https://files.pythonhosted.org/packages/36/16/ee2e06e65d9d533bc05cd52a0b355ba9072fc8f60d77289e529c6d2e3750/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e37385db03af65763933befe89fa70faf25301effc3b0485fec1c15d4ce4f052", size = 2054431 }, + { url = "https://files.pythonhosted.org/packages/d8/d5/2fac8315f210fa1bc7106e27c19e1211580aa25bb7fa17dfd79505e5baf2/cytoolz-1.0.1-cp311-cp311-win32.whl", hash = "sha256:50f9c530f83e3e574fc95c264c3350adde8145f4f8fc8099f65f00cc595e5ead", size = 322004 }, + { url = "https://files.pythonhosted.org/packages/a9/9e/0b70b641850a95f9ff90adde9d094a4b1d81ec54dadfd97fec0a2aaf440e/cytoolz-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:b7f6b617454b4326af7bd3c7c49b0fc80767f134eb9fd6449917a058d17a0e3c", size = 365358 }, + { url = "https://files.pythonhosted.org/packages/d9/f7/ef2a10daaec5c0f7d781d50758c6187eee484256e356ae8ef178d6c48497/cytoolz-1.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:83d19d55738ad9c60763b94f3f6d3c6e4de979aeb8d76841c1401081e0e58d96", size = 345702 }, + { url = "https://files.pythonhosted.org/packages/c8/14/53c84adddedb67ff1546abb86fea04d26e24298c3ceab8436d20122ed0b9/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f112a71fad6ea824578e6393765ce5c054603afe1471a5c753ff6c67fd872d10", size = 385695 }, + { url = "https://files.pythonhosted.org/packages/bd/80/3ae356c5e7b8d7dc7d1adb52f6932fee85cd748ed4e1217c269d2dfd610f/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a515df8f8aa6e1eaaf397761a6e4aff2eef73b5f920aedf271416d5471ae5ee", size = 406261 }, + { url = "https://files.pythonhosted.org/packages/0c/31/8e43761ffc82d90bf9cab7e0959712eedcd1e33c211397e143dd42d7af57/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c398e7b7023460bea2edffe5fcd0a76029580f06c3f6938ac3d198b47156f3", size = 397207 }, + { url = "https://files.pythonhosted.org/packages/d1/b9/fe9da37090b6444c65f848a83e390f87d8cb43d6a4df46de1556ad7e5ceb/cytoolz-1.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3237e56211e03b13df47435b2369f5df281e02b04ad80a948ebd199b7bc10a47", size = 343358 }, +] + +[[package]] +name = "dace" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aenum" }, + { name = "astunparse" }, + { name = "dill" }, + { name = "fparser" }, + { name = "networkx" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "ply" }, + { name = "pyreadline", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyyaml" }, + { name = "sympy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b4/67/fb1be2673868ee1f08e9c7bacc0b9b77d2bd5ff17ab47896f20006a2a1a5/dace-1.0.1.tar.gz", hash = "sha256:6f7a5defb082ed4f1a81f857d4268ed2bb606f6d9ea9c28d2831d1151e3a80f7", size = 5801727 } + +[[package]] +name = "debugpy" +version = "1.8.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/25/c74e337134edf55c4dfc9af579eccb45af2393c40960e2795a94351e8140/debugpy-1.8.12.tar.gz", hash = "sha256:646530b04f45c830ceae8e491ca1c9320a2d2f0efea3141487c82130aba70dce", size = 1641122 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/19/dd58334c0a1ec07babf80bf29fb8daf1a7ca4c1a3bbe61548e40616ac087/debugpy-1.8.12-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:a2ba7ffe58efeae5b8fad1165357edfe01464f9aef25e814e891ec690e7dd82a", size = 2076091 }, + { url = "https://files.pythonhosted.org/packages/4c/37/bde1737da15f9617d11ab7b8d5267165f1b7dae116b2585a6643e89e1fa2/debugpy-1.8.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbbd4149c4fc5e7d508ece083e78c17442ee13b0e69bfa6bd63003e486770f45", size = 3560717 }, + { url = "https://files.pythonhosted.org/packages/d9/ca/bc67f5a36a7de072908bc9e1156c0f0b272a9a2224cf21540ab1ffd71a1f/debugpy-1.8.12-cp310-cp310-win32.whl", hash = "sha256:b202f591204023b3ce62ff9a47baa555dc00bb092219abf5caf0e3718ac20e7c", size = 5180672 }, + { url = "https://files.pythonhosted.org/packages/c1/b9/e899c0a80dfa674dbc992f36f2b1453cd1ee879143cdb455bc04fce999da/debugpy-1.8.12-cp310-cp310-win_amd64.whl", hash = "sha256:9649eced17a98ce816756ce50433b2dd85dfa7bc92ceb60579d68c053f98dff9", size = 5212702 }, + { url = "https://files.pythonhosted.org/packages/af/9f/5b8af282253615296264d4ef62d14a8686f0dcdebb31a669374e22fff0a4/debugpy-1.8.12-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:36f4829839ef0afdfdd208bb54f4c3d0eea86106d719811681a8627ae2e53dd5", size = 2174643 }, + { url = "https://files.pythonhosted.org/packages/ef/31/f9274dcd3b0f9f7d1e60373c3fa4696a585c55acb30729d313bb9d3bcbd1/debugpy-1.8.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a28ed481d530e3138553be60991d2d61103ce6da254e51547b79549675f539b7", size = 3133457 }, + { url = "https://files.pythonhosted.org/packages/ab/ca/6ee59e9892e424477e0c76e3798046f1fd1288040b927319c7a7b0baa484/debugpy-1.8.12-cp311-cp311-win32.whl", hash = "sha256:4ad9a94d8f5c9b954e0e3b137cc64ef3f579d0df3c3698fe9c3734ee397e4abb", size = 5106220 }, + { url = "https://files.pythonhosted.org/packages/d5/1a/8ab508ab05ede8a4eae3b139bbc06ea3ca6234f9e8c02713a044f253be5e/debugpy-1.8.12-cp311-cp311-win_amd64.whl", hash = "sha256:4703575b78dd697b294f8c65588dc86874ed787b7348c65da70cfc885efdf1e1", size = 5130481 }, + { url = "https://files.pythonhosted.org/packages/38/c4/5120ad36405c3008f451f94b8f92ef1805b1e516f6ff870f331ccb3c4cc0/debugpy-1.8.12-py2.py3-none-any.whl", hash = "sha256:274b6a2040349b5c9864e475284bce5bb062e63dce368a394b8cc865ae3b00c6", size = 5229490 }, +] + +[[package]] +name = "decorator" +version = "5.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, +] + +[[package]] +name = "deepdiff" +version = "8.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "orderly-set" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/4b/ce2d3a36f77186d7dbca0f10b33e6a1c0eee390d9434960d2a14e2736b52/deepdiff-8.1.1.tar.gz", hash = "sha256:dd7bc7d5c8b51b5b90f01b0e2fe23c801fd8b4c6a7ee7e31c5a3c3663fcc7ceb", size = 433560 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/f7/2df72b55635926872b947203aacbe7e1109a51929aec8ebfef8c4a348eb5/deepdiff-8.1.1-py3-none-any.whl", hash = "sha256:b0231fa3afb0f7184e82535f2b4a36636442ed21e94a0cf3aaa7982157e7ebca", size = 84655 }, +] + +[[package]] +name = "devtools" +version = "0.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/75/b78198620640d394bc435c17bb49db18419afdd6cfa3ed8bcfe14034ec80/devtools-0.12.2.tar.gz", hash = "sha256:efceab184cb35e3a11fa8e602cc4fadacaa2e859e920fc6f87bf130b69885507", size = 75005 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/ae/afb1487556e2dc827a17097aac8158a25b433a345386f0e249f6d2694ccb/devtools-0.12.2-py3-none-any.whl", hash = "sha256:c366e3de1df4cdd635f1ad8cbcd3af01a384d7abda71900e68d43b04eb6aaca7", size = 19411 }, +] + +[[package]] +name = "dict2css" +version = "0.3.0.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cssutils" }, + { name = "domdf-python-tools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/eb/776eef1f1aa0188c0fc165c3a60b71027539f71f2eedc43ad21b060e9c39/dict2css-0.3.0.post1.tar.gz", hash = "sha256:89c544c21c4ca7472c3fffb9d37d3d926f606329afdb751dc1de67a411b70719", size = 7845 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/47/290daabcf91628f4fc0e17c75a1690b354ba067066cd14407712600e609f/dict2css-0.3.0.post1-py3-none-any.whl", hash = "sha256:f006a6b774c3e31869015122ae82c491fd25e7de4a75607a62aa3e798f837e0d", size = 25647 }, +] + +[[package]] +name = "dill" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/43/86fe3f9e130c4137b0f1b50784dd70a5087b911fe07fa81e53e0c4c47fea/dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c", size = 187000 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a", size = 119418 }, +] + +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550 }, +] + +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, +] + +[[package]] +name = "docutils" +version = "0.21.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/ed/aefcc8cd0ba62a0560c3c18c33925362d46c6075480bfa4df87b28e169a9/docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f", size = 2204444 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, +] + +[[package]] +name = "domdf-python-tools" +version = "3.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "natsort" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6b/78/974e10c583ba9d2302e748c9585313a7f2c7ba00e4f600324f432e38fe68/domdf_python_tools-3.9.0.tar.gz", hash = "sha256:1f8a96971178333a55e083e35610d7688cd7620ad2b99790164e1fc1a3614c18", size = 103792 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/e9/7447a88b217650a74927d3444a89507986479a69b83741900eddd34167fe/domdf_python_tools-3.9.0-py3-none-any.whl", hash = "sha256:4e1ef365cbc24627d6d1e90cf7d46d8ab8df967e1237f4a26885f6986c78872e", size = 127106 }, +] + +[[package]] +name = "esbonio" +version = "0.16.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pygls" }, + { name = "pyspellchecker" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/c5/0c89af3da1f3133b53f3ba8ae677ed4d4ddff33eec50dbf32c95e01ed2d2/esbonio-0.16.5.tar.gz", hash = "sha256:acab2e16c6cf8f7232fb04e0d48514ce50566516b1f6fcf669ccf2f247e8b10f", size = 145347 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/ca/a0296fca375d4324f471bb34d2ce8a585b48fb9eae21cf9abe00913eb899/esbonio-0.16.5-py3-none-any.whl", hash = "sha256:04ba926e3603f7b1fde1abc690b47afd60749b64b1029b6bce8e1de0bb284921", size = 170830 }, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, +] + +[[package]] +name = "execnet" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, +] + +[[package]] +name = "executing" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, +] + +[[package]] +name = "factory-boy" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "faker" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/3d/8070dde623341401b1c80156583d4c793058fe250450178218bb6e45526c/factory_boy-3.3.1.tar.gz", hash = "sha256:8317aa5289cdfc45f9cae570feb07a6177316c82e34d14df3c2e1f22f26abef0", size = 163924 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/cf/44ec67152f3129d0114c1499dd34f0a0a0faf43d9c2af05bc535746ca482/factory_boy-3.3.1-py2.py3-none-any.whl", hash = "sha256:7b1113c49736e1e9995bc2a18f4dbf2c52cf0f841103517010b1d825712ce3ca", size = 36878 }, +] + +[[package]] +name = "faker" +version = "35.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d5/18/86fe668976308d09e0178041c3756e646a1f5ddc676aa7fb0cf3cd52f5b9/faker-35.0.0.tar.gz", hash = "sha256:42f2da8cf561e38c72b25e9891168b1e25fec42b6b0b5b0b6cd6041da54af885", size = 1855098 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/fe/40452fb1730b10afa34dfe016097b28baa070ad74a1c1a3512ebed438c08/Faker-35.0.0-py3-none-any.whl", hash = "sha256:926d2301787220e0554c2e39afc4dc535ce4b0a8d0a089657137999f66334ef4", size = 1894841 }, +] + +[[package]] +name = "fastjsonschema" +version = "2.21.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/50/4b769ce1ac4071a1ef6d86b1a3fb56cdc3a37615e8c5519e1af96cdac366/fastjsonschema-2.21.1.tar.gz", hash = "sha256:794d4f0a58f848961ba16af7b9c85a3e88cd360df008c59aac6fc5ae9323b5d4", size = 373939 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924 }, +] + +[[package]] +name = "fastrlock" +version = "0.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/73/b1/1c3d635d955f2b4bf34d45abf8f35492e04dbd7804e94ce65d9f928ef3ec/fastrlock-0.8.3.tar.gz", hash = "sha256:4af6734d92eaa3ab4373e6c9a1dd0d5ad1304e172b1521733c6c3b3d73c8fa5d", size = 79327 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/02/3f771177380d8690812d5b2b7736dc6b6c8cd1c317e4572e65f823eede08/fastrlock-0.8.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:cc5fa9166e05409f64a804d5b6d01af670979cdb12cd2594f555cb33cdc155bd", size = 55094 }, + { url = "https://files.pythonhosted.org/packages/be/b4/aae7ed94b8122c325d89eb91336084596cebc505dc629b795fcc9629606d/fastrlock-0.8.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:7a77ebb0a24535ef4f167da2c5ee35d9be1e96ae192137e9dc3ff75b8dfc08a5", size = 48220 }, + { url = "https://files.pythonhosted.org/packages/96/87/9807af47617fdd65c68b0fcd1e714542c1d4d3a1f1381f591f1aa7383a53/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:d51f7fb0db8dab341b7f03a39a3031678cf4a98b18533b176c533c122bfce47d", size = 49551 }, + { url = "https://files.pythonhosted.org/packages/9d/12/e201634810ac9aee59f93e3953cb39f98157d17c3fc9d44900f1209054e9/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:767ec79b7f6ed9b9a00eb9ff62f2a51f56fdb221c5092ab2dadec34a9ccbfc6e", size = 49398 }, + { url = "https://files.pythonhosted.org/packages/15/a1/439962ed439ff6f00b7dce14927e7830e02618f26f4653424220a646cd1c/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d6a77b3f396f7d41094ef09606f65ae57feeb713f4285e8e417f4021617ca62", size = 53334 }, + { url = "https://files.pythonhosted.org/packages/b5/9e/1ae90829dd40559ab104e97ebe74217d9da794c4bb43016da8367ca7a596/fastrlock-0.8.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:92577ff82ef4a94c5667d6d2841f017820932bc59f31ffd83e4a2c56c1738f90", size = 52495 }, + { url = "https://files.pythonhosted.org/packages/e5/8c/5e746ee6f3d7afbfbb0d794c16c71bfd5259a4e3fb1dda48baf31e46956c/fastrlock-0.8.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3df8514086e16bb7c66169156a8066dc152f3be892c7817e85bf09a27fa2ada2", size = 51972 }, + { url = "https://files.pythonhosted.org/packages/76/a7/8b91068f00400931da950f143fa0f9018bd447f8ed4e34bed3fe65ed55d2/fastrlock-0.8.3-cp310-cp310-win_amd64.whl", hash = "sha256:001fd86bcac78c79658bac496e8a17472d64d558cd2227fdc768aa77f877fe40", size = 30946 }, + { url = "https://files.pythonhosted.org/packages/90/9e/647951c579ef74b6541493d5ca786d21a0b2d330c9514ba2c39f0b0b0046/fastrlock-0.8.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:f68c551cf8a34b6460a3a0eba44bd7897ebfc820854e19970c52a76bf064a59f", size = 55233 }, + { url = "https://files.pythonhosted.org/packages/be/91/5f3afba7d14b8b7d60ac651375f50fff9220d6ccc3bef233d2bd74b73ec7/fastrlock-0.8.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:55d42f6286b9d867370af4c27bc70d04ce2d342fe450c4a4fcce14440514e695", size = 48911 }, + { url = "https://files.pythonhosted.org/packages/d5/7a/e37bd72d7d70a8a551b3b4610d028bd73ff5d6253201d5d3cf6296468bee/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:bbc3bf96dcbd68392366c477f78c9d5c47e5d9290cb115feea19f20a43ef6d05", size = 50357 }, + { url = "https://files.pythonhosted.org/packages/0d/ef/a13b8bab8266840bf38831d7bf5970518c02603d00a548a678763322d5bf/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:77ab8a98417a1f467dafcd2226718f7ca0cf18d4b64732f838b8c2b3e4b55cb5", size = 50222 }, + { url = "https://files.pythonhosted.org/packages/01/e2/5e5515562b2e9a56d84659377176aef7345da2c3c22909a1897fe27e14dd/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:04bb5eef8f460d13b8c0084ea5a9d3aab2c0573991c880c0a34a56bb14951d30", size = 54553 }, + { url = "https://files.pythonhosted.org/packages/c0/8f/65907405a8cdb2fc8beaf7d09a9a07bb58deff478ff391ca95be4f130b70/fastrlock-0.8.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c9d459ce344c21ff03268212a1845aa37feab634d242131bc16c2a2355d5f65", size = 53362 }, + { url = "https://files.pythonhosted.org/packages/ec/b9/ae6511e52738ba4e3a6adb7c6a20158573fbc98aab448992ece25abb0b07/fastrlock-0.8.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:33e6fa4af4f3af3e9c747ec72d1eadc0b7ba2035456c2afb51c24d9e8a56f8fd", size = 52836 }, + { url = "https://files.pythonhosted.org/packages/88/3e/c26f8192c93e8e43b426787cec04bb46ac36e72b1033b7fe5a9267155fdf/fastrlock-0.8.3-cp311-cp311-win_amd64.whl", hash = "sha256:5e5f1665d8e70f4c5b4a67f2db202f354abc80a321ce5a26ac1493f055e3ae2c", size = 31046 }, +] + +[[package]] +name = "filelock" +version = "3.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/9c/0b15fb47b464e1b663b1acd1253a062aa5feecb07d4e597daea542ebd2b5/filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e", size = 18027 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/ec/00d68c4ddfedfe64159999e5f8a98fb8442729a63e2077eb9dcd89623d27/filelock-3.17.0-py3-none-any.whl", hash = "sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338", size = 16164 }, +] + +[[package]] +name = "fonttools" +version = "4.55.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/55/55/3b1566c6186a5e58a17a19ad63195f87c6ca4039ef10ff5318a1b9fc5639/fonttools-4.55.7.tar.gz", hash = "sha256:6899e3d97225a8218f525e9754da0376e1c62953a0d57a76c5abaada51e0d140", size = 3458372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/5c/ce2fce845af9696d043ac912f15b9fac4b9002fcd9ff66b80aa513a6c43f/fonttools-4.55.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c2680a3e6e2e2d104a7ea81fb89323e1a9122c23b03d6569d0768887d0d76e69", size = 2752048 }, + { url = "https://files.pythonhosted.org/packages/07/9b/f7f9409adcf22763263c6327d2d31d538babd9ad2d63d1732c9e85d60a78/fonttools-4.55.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a7831d16c95b60866772a15fdcc03772625c4bb6d858e0ad8ef3d6e48709b2ef", size = 2280495 }, + { url = "https://files.pythonhosted.org/packages/91/df/348cf4ff1becd63ed952e35e436de3f9fd3245edb74c070457b465c40a58/fonttools-4.55.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:833927d089e6585019f2c85e3f8f7d87733e3fe81cd704ebaca7afa27e2e7113", size = 4561947 }, + { url = "https://files.pythonhosted.org/packages/14/fe/48b808bdf14bb9467e4a5aaa8aa89f8aba9979d52be3f7f1962f065e933e/fonttools-4.55.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7858dc6823296a053d85b831fa8428781c6c6f06fca44582bf7b6b2ff32a9089", size = 4604618 }, + { url = "https://files.pythonhosted.org/packages/52/25/305d88761aa15a8b2761869a15db34c070e72756d166a163756c53d07b35/fonttools-4.55.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:05568a66b090ed9d79aefdce2ceb180bb64fc856961deaedc29f5ad51355ce2c", size = 4558896 }, + { url = "https://files.pythonhosted.org/packages/0c/0b/c6f7877611940ab75dbe50f035d16ca5ce6d9ff2e5e65b9c76da830286ff/fonttools-4.55.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2dbc08e227fbeb716776905a7bd3c4fc62c8e37c8ef7d481acd10cb5fde12222", size = 4728347 }, + { url = "https://files.pythonhosted.org/packages/43/2c/490223b8cfaeccdef3d8819945a455aa8cc57f12f49233a3d40556b739cc/fonttools-4.55.7-cp310-cp310-win32.whl", hash = "sha256:6eb93cbba484a463b5ee83f7dd3211905f27a3871d20d90fb72de84c6c5056e3", size = 2155437 }, + { url = "https://files.pythonhosted.org/packages/37/f8/ee47526b3f03596cbed9dc7f38519cb650e7769bf9365e04bd81ff4a5302/fonttools-4.55.7-cp310-cp310-win_amd64.whl", hash = "sha256:7ff8e606f905048dc91a55a06d994b68065bf35752ae199df54a9bf30013dcaa", size = 2199898 }, + { url = "https://files.pythonhosted.org/packages/07/cb/f1dd2e31553bd03dcb4eb3af1ac6acc7fe41f26067d1bba104005ec1bb04/fonttools-4.55.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:916e1d926823b4b3b3815c59fc79f4ed670696fdd5fd9a5e690a0503eef38f79", size = 2753201 }, + { url = "https://files.pythonhosted.org/packages/21/84/f9f82093789947547b4bc86242669cde816ef4d949b23f472e47e85f125d/fonttools-4.55.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b89da448e0073408d7b2c44935f9fdae4fdc93644899f99f6102ef883ecf083c", size = 2281418 }, + { url = "https://files.pythonhosted.org/packages/46/e1/e0398d2aa7bf5400c84650fc7d85708502289bb92a40f8090e6e71cfe315/fonttools-4.55.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:087ace2d06894ccdb03e6975d05da6bb9cec0c689b2a9983c059880e33a1464a", size = 4869132 }, + { url = "https://files.pythonhosted.org/packages/d4/2d/9d86cd653c758334285a5c95d1bc0a7f13b6a72fc674c6b33fef3b8e3f77/fonttools-4.55.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775ed0700ee6f781436641f18a0c61b1846a8c1aecae6da6b395c4417e2cb567", size = 4898375 }, + { url = "https://files.pythonhosted.org/packages/48/ce/f49fccb7d9f7c9c6d239434fc48546a0b37a91ba8310c7bcd5127cfeb5f6/fonttools-4.55.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9ec71d0cc0242899f87e4c230ed0b22c7b8681f288fb80e3d81c2c54c5bd2c79", size = 4877574 }, + { url = "https://files.pythonhosted.org/packages/cc/85/afe73e96a1572ba0acc86e82d52554bf69f384b431acd7a15b8c3890833b/fonttools-4.55.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d4b1c5939c0521525f45522823508e6fad21175bca978583688ea3b3736e6625", size = 5045681 }, + { url = "https://files.pythonhosted.org/packages/b8/37/dc59bc5a2f049d39b62996c806c147ae2eee5316f047a37bcf4cb9dbc4ef/fonttools-4.55.7-cp311-cp311-win32.whl", hash = "sha256:23df0f1003abaf8a435543f59583fc247e7ae1b047ee2263510e0654a5f207e0", size = 2154302 }, + { url = "https://files.pythonhosted.org/packages/86/33/281989403a57945c7871df144af3512ad3d1cd223e025b08b7f377847e6d/fonttools-4.55.7-cp311-cp311-win_amd64.whl", hash = "sha256:82163d58b43eff6e2025a25c32905fdb9042a163cc1ff82dab393e7ffc77a7d5", size = 2200818 }, + { url = "https://files.pythonhosted.org/packages/7b/6d/304a16caf63a8c193ec387b1fae1cb10072a59d34549f2eefe7e3fa9f364/fonttools-4.55.7-py3-none-any.whl", hash = "sha256:3304dfcf9ca204dd0ef691a287bd851ddd8e8250108658c0677c3fdfec853a20", size = 1089677 }, +] + +[[package]] +name = "fparser" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools-scm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/af/570c36d7bc374646ab82f579e2bf9d24a619cc53d83f95b38b0992de3492/fparser-0.2.0.tar.gz", hash = "sha256:3901d31c104062c4e532248286929e7405e43b79a6a85815146a176673e69c82", size = 433559 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/91/03999b30650f5621dd5ec9e8245024dea1b71c4e28e52e0c7300aa0c769d/fparser-0.2.0-py3-none-any.whl", hash = "sha256:49fab105e3a977b9b9d5d4489649287c5060e94c688f9936f3d5af3a45d6f4eb", size = 639408 }, +] + +[[package]] +name = "frozendict" +version = "2.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/59/19eb300ba28e7547538bdf603f1c6c34793240a90e1a7b61b65d8517e35e/frozendict-2.4.6.tar.gz", hash = "sha256:df7cd16470fbd26fc4969a208efadc46319334eb97def1ddf48919b351192b8e", size = 316416 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/7f/e80cdbe0db930b2ba9d46ca35a41b0150156da16dfb79edcc05642690c3b/frozendict-2.4.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c3a05c0a50cab96b4bb0ea25aa752efbfceed5ccb24c007612bc63e51299336f", size = 37927 }, + { url = "https://files.pythonhosted.org/packages/29/98/27e145ff7e8e63caa95fb8ee4fc56c68acb208bef01a89c3678a66f9a34d/frozendict-2.4.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f5b94d5b07c00986f9e37a38dd83c13f5fe3bf3f1ccc8e88edea8fe15d6cd88c", size = 37945 }, + { url = "https://files.pythonhosted.org/packages/ac/f1/a10be024a9d53441c997b3661ea80ecba6e3130adc53812a4b95b607cdd1/frozendict-2.4.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4c789fd70879ccb6289a603cdebdc4953e7e5dea047d30c1b180529b28257b5", size = 117656 }, + { url = "https://files.pythonhosted.org/packages/46/a6/34c760975e6f1cb4db59a990d58dcf22287e10241c851804670c74c6a27a/frozendict-2.4.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da6a10164c8a50b34b9ab508a9420df38f4edf286b9ca7b7df8a91767baecb34", size = 117444 }, + { url = "https://files.pythonhosted.org/packages/62/dd/64bddd1ffa9617f50e7e63656b2a7ad7f0a46c86b5f4a3d2c714d0006277/frozendict-2.4.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9a8a43036754a941601635ea9c788ebd7a7efbed2becba01b54a887b41b175b9", size = 116801 }, + { url = "https://files.pythonhosted.org/packages/45/ae/af06a8bde1947277aad895c2f26c3b8b8b6ee9c0c2ad988fb58a9d1dde3f/frozendict-2.4.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c9905dcf7aa659e6a11b8051114c9fa76dfde3a6e50e6dc129d5aece75b449a2", size = 117329 }, + { url = "https://files.pythonhosted.org/packages/d2/df/be3fa0457ff661301228f4c59c630699568c8ed9b5480f113b3eea7d0cb3/frozendict-2.4.6-cp310-cp310-win_amd64.whl", hash = "sha256:323f1b674a2cc18f86ab81698e22aba8145d7a755e0ac2cccf142ee2db58620d", size = 37522 }, + { url = "https://files.pythonhosted.org/packages/4a/6f/c22e0266b4c85f58b4613fec024e040e93753880527bf92b0c1bc228c27c/frozendict-2.4.6-cp310-cp310-win_arm64.whl", hash = "sha256:eabd21d8e5db0c58b60d26b4bb9839cac13132e88277e1376970172a85ee04b3", size = 34056 }, + { url = "https://files.pythonhosted.org/packages/04/13/d9839089b900fa7b479cce495d62110cddc4bd5630a04d8469916c0e79c5/frozendict-2.4.6-py311-none-any.whl", hash = "sha256:d065db6a44db2e2375c23eac816f1a022feb2fa98cbb50df44a9e83700accbea", size = 16148 }, +] + +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794 }, +] + +[[package]] +name = "gitpython" +version = "3.1.44" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/89/37df0b71473153574a5cdef8f242de422a0f5d26d7a9e231e6f169b4ad14/gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269", size = 214196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, +] + +[[package]] +name = "gridtools-cpp" +version = "2.3.8" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/b8/352120417da7a3e16cc822e95668e1843d0cd9ee7f0269b9a098893471cc/gridtools_cpp-2.3.8-py3-none-any.whl", hash = "sha256:d9cb8aadc5dca7e864677072de15596feb883844eee2158ab108d04f2f17f355", size = 420716 }, +] + +[[package]] +name = "gt4py" +version = "1.0.4" +source = { editable = "." } +dependencies = [ + { name = "attrs" }, + { name = "black" }, + { name = "boltons" }, + { name = "cached-property" }, + { name = "click" }, + { name = "cmake" }, + { name = "cytoolz" }, + { name = "deepdiff" }, + { name = "devtools" }, + { name = "diskcache" }, + { name = "factory-boy" }, + { name = "filelock" }, + { name = "frozendict" }, + { name = "gridtools-cpp" }, + { name = "jinja2" }, + { name = "lark" }, + { name = "mako" }, + { name = "nanobind" }, + { name = "ninja" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pybind11" }, + { name = "setuptools" }, + { name = "tabulate" }, + { name = "toolz" }, + { name = "typing-extensions" }, + { name = "xxhash" }, +] + +[package.optional-dependencies] +all = [ + { name = "clang-format" }, + { name = "dace" }, + { name = "hypothesis" }, + { name = "jax" }, + { name = "pytest" }, + { name = "scipy" }, +] +cuda11 = [ + { name = "cupy-cuda11x" }, +] +cuda12 = [ + { name = "cupy-cuda12x" }, +] +dace = [ + { name = "dace" }, +] +formatting = [ + { name = "clang-format" }, +] +jax = [ + { name = "jax" }, +] +jax-cuda12 = [ + { name = "cupy-cuda12x" }, + { name = "jax", extra = ["cuda12-local"] }, +] +performance = [ + { name = "scipy" }, +] +rocm4-3 = [ + { name = "cupy-rocm-4-3" }, +] +rocm5-0 = [ + { name = "cupy-rocm-5-0" }, +] +testing = [ + { name = "hypothesis" }, + { name = "pytest" }, +] + +[package.dev-dependencies] +build = [ + { name = "bump-my-version" }, + { name = "cython" }, + { name = "pip" }, + { name = "setuptools" }, + { name = "wheel" }, +] +dev = [ + { name = "atlas4py" }, + { name = "bump-my-version" }, + { name = "coverage", extra = ["toml"] }, + { name = "cython" }, + { name = "esbonio" }, + { name = "hypothesis" }, + { name = "jupytext" }, + { name = "matplotlib" }, + { name = "mypy", extra = ["faster-cache"] }, + { name = "myst-parser" }, + { name = "nbmake" }, + { name = "nox" }, + { name = "pip" }, + { name = "pre-commit" }, + { name = "pygments" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cache" }, + { name = "pytest-cov" }, + { name = "pytest-factoryboy" }, + { name = "pytest-instafail" }, + { name = "pytest-xdist", extra = ["psutil"] }, + { name = "ruff" }, + { name = "setuptools" }, + { name = "sphinx" }, + { name = "sphinx-rtd-theme" }, + { name = "sphinx-toolbox" }, + { name = "tach" }, + { name = "types-decorator" }, + { name = "types-docutils" }, + { name = "types-pytz" }, + { name = "types-pyyaml" }, + { name = "types-tabulate" }, + { name = "wheel" }, +] +docs = [ + { name = "esbonio" }, + { name = "jupytext" }, + { name = "matplotlib" }, + { name = "myst-parser" }, + { name = "pygments" }, + { name = "sphinx" }, + { name = "sphinx-rtd-theme" }, + { name = "sphinx-toolbox" }, +] +frameworks = [ + { name = "atlas4py" }, +] +lint = [ + { name = "pre-commit" }, + { name = "ruff" }, + { name = "tach" }, +] +test = [ + { name = "coverage", extra = ["toml"] }, + { name = "hypothesis" }, + { name = "nbmake" }, + { name = "nox" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cache" }, + { name = "pytest-cov" }, + { name = "pytest-factoryboy" }, + { name = "pytest-instafail" }, + { name = "pytest-xdist", extra = ["psutil"] }, +] +typing = [ + { name = "mypy", extra = ["faster-cache"] }, + { name = "types-decorator" }, + { name = "types-docutils" }, + { name = "types-pytz" }, + { name = "types-pyyaml" }, + { name = "types-tabulate" }, +] + +[package.metadata] +requires-dist = [ + { name = "attrs", specifier = ">=21.3" }, + { name = "black", specifier = ">=22.3" }, + { name = "boltons", specifier = ">=20.1" }, + { name = "cached-property", specifier = ">=1.5.1" }, + { name = "clang-format", marker = "extra == 'formatting'", specifier = ">=9.0" }, + { name = "click", specifier = ">=8.0.0" }, + { name = "cmake", specifier = ">=3.22" }, + { name = "cupy-cuda11x", marker = "extra == 'cuda11'", specifier = ">=12.0" }, + { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=12.0" }, + { name = "cupy-rocm-4-3", marker = "extra == 'rocm4-3'", specifier = ">=13.3.0" }, + { name = "cupy-rocm-5-0", marker = "extra == 'rocm5-0'", specifier = ">=13.3.0" }, + { name = "cytoolz", specifier = ">=0.12.1" }, + { name = "dace", marker = "extra == 'dace'", specifier = ">=1.0.0,<1.1.0" }, + { name = "deepdiff", specifier = ">=5.6.0" }, + { name = "devtools", specifier = ">=0.6" }, + { name = "diskcache", specifier = ">=5.6.3" }, + { name = "factory-boy", specifier = ">=3.3.0" }, + { name = "filelock", specifier = ">=3.16.1" }, + { name = "frozendict", specifier = ">=2.3" }, + { name = "gridtools-cpp", specifier = "==2.*,>=2.3.8" }, + { name = "gt4py", extras = ["cuda12"], marker = "extra == 'jax-cuda12'" }, + { name = "gt4py", extras = ["dace", "formatting", "jax", "performance", "testing"], marker = "extra == 'all'" }, + { name = "hypothesis", marker = "extra == 'testing'", specifier = ">=6.0.0" }, + { name = "jax", marker = "extra == 'jax'", specifier = ">=0.4.26" }, + { name = "jax", extras = ["cuda12-local"], marker = "extra == 'jax-cuda12'", specifier = ">=0.4.26" }, + { name = "jinja2", specifier = ">=3.0.0" }, + { name = "lark", specifier = ">=1.1.2" }, + { name = "mako", specifier = ">=1.1" }, + { name = "nanobind", specifier = ">=1.4.0" }, + { name = "ninja", specifier = ">=1.10" }, + { name = "numpy", specifier = ">=1.23.3" }, + { name = "packaging", specifier = ">=20.0" }, + { name = "pybind11", specifier = ">=2.10.1" }, + { name = "pytest", marker = "extra == 'testing'", specifier = ">=7.0" }, + { name = "scipy", marker = "extra == 'performance'", specifier = ">=1.9.2" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "tabulate", specifier = ">=0.8.10" }, + { name = "toolz", specifier = ">=0.12.1" }, + { name = "typing-extensions", specifier = ">=4.11.0" }, + { name = "xxhash", specifier = ">=1.4.4,<3.1.0" }, +] + +[package.metadata.requires-dev] +build = [ + { name = "bump-my-version", specifier = ">=0.16.0" }, + { name = "cython", specifier = ">=3.0.0" }, + { name = "pip", specifier = ">=22.1.1" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "wheel", specifier = ">=0.33.6" }, +] +dev = [ + { name = "atlas4py", specifier = ">=0.35", index = "https://test.pypi.org/simple/" }, + { name = "bump-my-version", specifier = ">=0.16.0" }, + { name = "coverage", extras = ["toml"], specifier = ">=7.5.0" }, + { name = "cython", specifier = ">=3.0.0" }, + { name = "esbonio", specifier = ">=0.16.0" }, + { name = "hypothesis", specifier = ">=6.0.0" }, + { name = "jupytext", specifier = ">=1.14" }, + { name = "matplotlib", specifier = ">=3.3" }, + { name = "mypy", extras = ["faster-cache"], specifier = ">=1.13.0" }, + { name = "myst-parser", specifier = ">=4.0.0" }, + { name = "nbmake", specifier = ">=1.4.6" }, + { name = "nox", specifier = ">=2024.10.9" }, + { name = "pip", specifier = ">=22.1.1" }, + { name = "pre-commit", specifier = ">=4.0.1" }, + { name = "pygments", specifier = ">=2.7.3" }, + { name = "pytest", specifier = ">=8.0.1" }, + { name = "pytest-benchmark", specifier = ">=5.0.0" }, + { name = "pytest-cache", specifier = ">=1.0" }, + { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-factoryboy", specifier = ">=2.6.1" }, + { name = "pytest-instafail", specifier = ">=0.5.0" }, + { name = "pytest-xdist", extras = ["psutil"], specifier = ">=3.5.0" }, + { name = "ruff", specifier = ">=0.8.0" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "sphinx", specifier = ">=7.3.7" }, + { name = "sphinx-rtd-theme", specifier = ">=3.0.1" }, + { name = "sphinx-toolbox", specifier = ">=3.8.1" }, + { name = "tach", specifier = ">=0.16.0" }, + { name = "types-decorator", specifier = ">=5.1.8" }, + { name = "types-docutils", specifier = ">=0.21.0" }, + { name = "types-pytz", specifier = ">=2024.2.0" }, + { name = "types-pyyaml", specifier = ">=6.0.10" }, + { name = "types-tabulate", specifier = ">=0.8.10" }, + { name = "wheel", specifier = ">=0.33.6" }, +] +docs = [ + { name = "esbonio", specifier = ">=0.16.0" }, + { name = "jupytext", specifier = ">=1.14" }, + { name = "matplotlib", specifier = ">=3.3" }, + { name = "myst-parser", specifier = ">=4.0.0" }, + { name = "pygments", specifier = ">=2.7.3" }, + { name = "sphinx", specifier = ">=7.3.7" }, + { name = "sphinx-rtd-theme", specifier = ">=3.0.1" }, + { name = "sphinx-toolbox", specifier = ">=3.8.1" }, +] +frameworks = [{ name = "atlas4py", specifier = ">=0.35", index = "https://test.pypi.org/simple/" }] +lint = [ + { name = "pre-commit", specifier = ">=4.0.1" }, + { name = "ruff", specifier = ">=0.8.0" }, + { name = "tach", specifier = ">=0.16.0" }, +] +test = [ + { name = "coverage", extras = ["toml"], specifier = ">=7.5.0" }, + { name = "hypothesis", specifier = ">=6.0.0" }, + { name = "nbmake", specifier = ">=1.4.6" }, + { name = "nox", specifier = ">=2024.10.9" }, + { name = "pytest", specifier = ">=8.0.1" }, + { name = "pytest-benchmark", specifier = ">=5.0.0" }, + { name = "pytest-cache", specifier = ">=1.0" }, + { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-factoryboy", specifier = ">=2.6.1" }, + { name = "pytest-instafail", specifier = ">=0.5.0" }, + { name = "pytest-xdist", extras = ["psutil"], specifier = ">=3.5.0" }, +] +typing = [ + { name = "mypy", extras = ["faster-cache"], specifier = ">=1.13.0" }, + { name = "types-decorator", specifier = ">=5.1.8" }, + { name = "types-docutils", specifier = ">=0.21.0" }, + { name = "types-pytz", specifier = ">=2024.2.0" }, + { name = "types-pyyaml", specifier = ">=6.0.10" }, + { name = "types-tabulate", specifier = ">=0.8.10" }, +] + +[[package]] +name = "html5lib" +version = "1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/b55c3f49042f1df3dcd422b7f224f939892ee94f22abcf503a9b7339eaf2/html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f", size = 272215 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/dd/a834df6482147d48e225a49515aabc28974ad5a4ca3215c18a882565b028/html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d", size = 112173 }, +] + +[[package]] +name = "hypothesis" +version = "6.124.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/ef/6e3736663ee67369f7f5b697674bfbd3efc91e7096ddd4452bbbc80065ff/hypothesis-6.124.7.tar.gz", hash = "sha256:8ed6c6ae47e7d26d869c1dc3dee04e8fc50c95240715bb9915ded88d6d920f0e", size = 416938 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/48/2412d4aacf1c50882126910ce036c92a838784915e3de66fb603a75c05ec/hypothesis-6.124.7-py3-none-any.whl", hash = "sha256:a6e1f66de84de3152d57f595a187a123ce3ecdea9dc8ef51ff8dcaa069137085", size = 479518 }, +] + +[[package]] +name = "identify" +version = "2.6.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/bf/c68c46601bacd4c6fb4dd751a42b6e7087240eaabc6487f2ef7a48e0e8fc/identify-2.6.6.tar.gz", hash = "sha256:7bec12768ed44ea4761efb47806f0a41f86e7c0a5fdf5950d4648c90eca7e251", size = 99217 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/a1/68a395c17eeefb04917034bd0a1bfa765e7654fa150cca473d669aa3afb5/identify-2.6.6-py2.py3-none-any.whl", hash = "sha256:cbd1810bce79f8b671ecb20f53ee0ae8e86ae84b557de31d89709dc2a48ba881", size = 99083 }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, +] + +[[package]] +name = "imagesize" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/84/62473fb57d61e31fef6e36d64a179c8781605429fd927b5dd608c997be31/imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", size = 1280026 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769 }, +] + +[[package]] +name = "inflection" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/7e/691d061b7329bc8d54edbf0ec22fbfb2afe61facb681f9aaa9bff7a27d04/inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417", size = 15091 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/91/aa6bde563e0085a02a435aa99b49ef75b0a4b062635e606dab23ce18d720/inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2", size = 9454 }, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "ipykernel" +version = "6.29.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "platform_system == 'Darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/5c/67594cb0c7055dc50814b21731c22a601101ea3b1b50a9a1b090e11f5d0f/ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215", size = 163367 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/5c/368ae6c01c7628438358e6d337c19b05425727fbb221d2a3c4303c372f42/ipykernel-6.29.5-py3-none-any.whl", hash = "sha256:afdb66ba5aa354b09b91379bac28ae4afebbb30e8b39510c9690afb7a10421b5", size = 117173 }, +] + +[[package]] +name = "ipython" +version = "8.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "decorator" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "(sys_platform != 'emscripten' and sys_platform != 'win32') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { name = "traitlets" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/35/6f90fdddff7a08b7b715fccbd2427b5212c9525cd043d26fdc45bee0708d/ipython-8.31.0.tar.gz", hash = "sha256:b6a2274606bec6166405ff05e54932ed6e5cfecaca1fc05f2cacde7bb074d70b", size = 5501011 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/60/d0feb6b6d9fe4ab89fe8fe5b47cbf6cd936bfd9f1e7ffa9d0015425aeed6/ipython-8.31.0-py3-none-any.whl", hash = "sha256:46ec58f8d3d076a61d128fe517a51eb730e3aaf0c184ea8c17d16e366660c6a6", size = 821583 }, +] + +[[package]] +name = "jax" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" }, + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/22d62b26284f08e62d6eb64603d3b010004cfdb7a97ce6cca5c6cf86edab/jax-0.5.0.tar.gz", hash = "sha256:49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8", size = 1959707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/58/cc0721a1030fcbab0984beea0bf3c4610ec103f738423cdfa9c4ceb40598/jax-0.5.0-py3-none-any.whl", hash = "sha256:b3907aa87ae2c340b39cdbf80c07a74550369cafcaf7398fb60ba58d167345ab", size = 2270365 }, +] + +[package.optional-dependencies] +cuda12-local = [ + { name = "jax-cuda12-plugin" }, + { name = "jaxlib" }, +] + +[[package]] +name = "jax-cuda12-pjrt" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/a6/4b161016aaafe04d92e8d9a50b47e6767ea5cf874a8a9d2d1bcd049409d3/jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6025cd4b32d8ec04a11705a749764cd96a6cbc8b6273beac947cc481f2584b8c", size = 89441461 }, + { url = "https://files.pythonhosted.org/packages/8e/ac/824ff70eb5b5dd2a4b597a2017ae62f24b9aaa5fd846f04c94dc447aa1ec/jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d23833c1b885d96c2764000e95052f2b5827c77d492ea68f67e903a132656dbb", size = 103122594 }, +] + +[[package]] +name = "jax-cuda12-plugin" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax-cuda12-pjrt" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/58/3dab6bb4cdbc43663093c2af4671e87312236a23c84a3fc152d3c3979019/jax_cuda12_plugin-0.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d497dcc9205a11d283c308d8f400fb71507cf808753168d47effd1d4c47f9c3d", size = 16777702 }, + { url = "https://files.pythonhosted.org/packages/c2/46/a54402df9e2d057bb16d7e2ab045bd536fc8b83662cfc8d503fc56f5fc41/jax_cuda12_plugin-0.5.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:0f443a6b37298edfb0796fcdbd1f86ce85a4b084b6bd3f1f50a4fbfd67ded86b", size = 16733143 }, + { url = "https://files.pythonhosted.org/packages/d9/d5/64ad0b832122d938cbad07652625679a35c03e16e2ce4b8eda4ead8feed5/jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:25407ccb030e4eed7d7e2ccccac8ab65f932aa05936ca5cf0e8ded4adfdcad1a", size = 16777553 }, + { url = "https://files.pythonhosted.org/packages/a2/7b/cc9fa545db9397de9054357de8440c8b10d28a6ab5d1cef1eba184c3d426/jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a98135a0223064b8f5c6853e22ddc1a4e3862152d37fb685f0dbdeffe0c80122", size = 16734352 }, +] + +[[package]] +name = "jaxlib" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/41/3e4ac64df72c4da126df3fd66a2214025a46b6263f7be266728e7b8e473e/jaxlib-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1b8a6c4345f137f387650de2dbc488c20251b7412b55dd648e1a4f13bcf507fb", size = 79248968 }, + { url = "https://files.pythonhosted.org/packages/1e/5f/2a16e61f1d54ae5f55fbf3cb3e22ef5bb01bf9d7d6474e0d34fedba19c4d/jaxlib-0.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5b2efe3dfebf18a84c451d3803ac884ee242021c1113b279c13f4bbc378c3dc0", size = 93181077 }, + { url = "https://files.pythonhosted.org/packages/08/c3/573e2f01b99f1247e8fbe1aa46b95a0faa68ef208f9a8e8ef775d607b3e6/jaxlib-0.5.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:74440b632107336400d4f97a16481d767f13ea914c53ba14e544c6fda54819b3", size = 101969119 }, + { url = "https://files.pythonhosted.org/packages/6e/38/512f61ea13da41ca47f2411d7c05af0cf74a37f225e16725ed0e6fb58893/jaxlib-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:53478a28eee6c2ef01759b05a9491702daef9268c3ed013d6f8e2e5f5cae0887", size = 63883394 }, + { url = "https://files.pythonhosted.org/packages/92/4b/8875870ff52ad3fbea876c905228f691f05c8dc8556b226cbfaf0fba7f62/jaxlib-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6cd762ed1623132499fa701c4203446102e0a9c82ca23194b87288f746d12a29", size = 79242870 }, + { url = "https://files.pythonhosted.org/packages/a0/0f/00cdfa411d7218e4696c10c5867f7d3c396219adbcaeb02e95108ca802de/jaxlib-0.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:63088dbfaa85bb56cd521a925a3472fd7328b18ec93c2d8ffa85af331095c995", size = 93181807 }, + { url = "https://files.pythonhosted.org/packages/58/8e/a5c29db03d5a93b0326e297b556d0e0a9805e9c9c1ae5f82f69557273faa/jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:09113ef1582ba34d7cbc440fedb318f4855b59b776711a8aba2473c9727d3025", size = 101969212 }, + { url = "https://files.pythonhosted.org/packages/70/86/ceae20e4f37fa07f1cc95551cc0f49170d0db46d2e82fdf511d26bffd801/jaxlib-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:78289fc3ddc1e4e9510de2536a6375df9fe1c50de0ac60826c286b7a5c5090fe", size = 63881994 }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, +] + +[[package]] +name = "jinja2" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 }, +] + +[[package]] +name = "jsonschema" +version = "4.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/2e/03362ee4034a4c917f697890ccd4aec0800ccf9ded7f511971c75451deec/jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4", size = 325778 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566", size = 88462 }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2024.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/db/58f950c996c793472e336ff3655b13fbcf1e3b359dcf52dcf3ed3b52c352/jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272", size = 15561 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 }, +] + +[[package]] +name = "jupyter-client" +version = "8.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105 }, +] + +[[package]] +name = "jupyter-core" +version = "5.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pywin32", marker = "(platform_python_implementation != 'PyPy' and sys_platform == 'win32') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, +] + +[[package]] +name = "jupytext" +version = "1.16.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "mdit-py-plugins" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/e7/58d6fd374e1065d2bccefd07953d2f1f911d8de03fd7dc33dd5a25ac659c/jupytext-1.16.6.tar.gz", hash = "sha256:dbd03f9263c34b737003f388fc069e9030834fb7136879c4c32c32473557baa0", size = 3726029 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/02/27191f18564d4f2c0e543643aa94b54567de58f359cd6a3bed33adb723ac/jupytext-1.16.6-py3-none-any.whl", hash = "sha256:900132031f73fee15a1c9ebd862e05eb5f51e1ad6ab3a2c6fdd97ce2f9c913b4", size = 154200 }, +] + +[[package]] +name = "kiwisolver" +version = "1.4.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/59/7c91426a8ac292e1cdd53a63b6d9439abd573c875c3f92c146767dd33faf/kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e", size = 97538 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/5f/4d8e9e852d98ecd26cdf8eaf7ed8bc33174033bba5e07001b289f07308fd/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88c6f252f6816a73b1f8c904f7bbe02fd67c09a69f7cb8a0eecdbf5ce78e63db", size = 124623 }, + { url = "https://files.pythonhosted.org/packages/1d/70/7f5af2a18a76fe92ea14675f8bd88ce53ee79e37900fa5f1a1d8e0b42998/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72941acb7b67138f35b879bbe85be0f6c6a70cab78fe3ef6db9c024d9223e5b", size = 66720 }, + { url = "https://files.pythonhosted.org/packages/c6/13/e15f804a142353aefd089fadc8f1d985561a15358c97aca27b0979cb0785/kiwisolver-1.4.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce2cf1e5688edcb727fdf7cd1bbd0b6416758996826a8be1d958f91880d0809d", size = 65413 }, + { url = "https://files.pythonhosted.org/packages/ce/6d/67d36c4d2054e83fb875c6b59d0809d5c530de8148846b1370475eeeece9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c8bf637892dc6e6aad2bc6d4d69d08764166e5e3f69d469e55427b6ac001b19d", size = 1650826 }, + { url = "https://files.pythonhosted.org/packages/de/c6/7b9bb8044e150d4d1558423a1568e4f227193662a02231064e3824f37e0a/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c", size = 1628231 }, + { url = "https://files.pythonhosted.org/packages/b6/38/ad10d437563063eaaedbe2c3540a71101fc7fb07a7e71f855e93ea4de605/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d47b28d1dfe0793d5e96bce90835e17edf9a499b53969b03c6c47ea5985844c3", size = 1408938 }, + { url = "https://files.pythonhosted.org/packages/52/ce/c0106b3bd7f9e665c5f5bc1e07cc95b5dabd4e08e3dad42dbe2faad467e7/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb158fe28ca0c29f2260cca8c43005329ad58452c36f0edf298204de32a9a3ed", size = 1422799 }, + { url = "https://files.pythonhosted.org/packages/d0/87/efb704b1d75dc9758087ba374c0f23d3254505edaedd09cf9d247f7878b9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5536185fce131780ebd809f8e623bf4030ce1b161353166c49a3c74c287897f", size = 1354362 }, + { url = "https://files.pythonhosted.org/packages/eb/b3/fd760dc214ec9a8f208b99e42e8f0130ff4b384eca8b29dd0efc62052176/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:369b75d40abedc1da2c1f4de13f3482cb99e3237b38726710f4a793432b1c5ff", size = 2222695 }, + { url = "https://files.pythonhosted.org/packages/a2/09/a27fb36cca3fc01700687cc45dae7a6a5f8eeb5f657b9f710f788748e10d/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:641f2ddf9358c80faa22e22eb4c9f54bd3f0e442e038728f500e3b978d00aa7d", size = 2370802 }, + { url = "https://files.pythonhosted.org/packages/3d/c3/ba0a0346db35fe4dc1f2f2cf8b99362fbb922d7562e5f911f7ce7a7b60fa/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d561d2d8883e0819445cfe58d7ddd673e4015c3c57261d7bdcd3710d0d14005c", size = 2334646 }, + { url = "https://files.pythonhosted.org/packages/41/52/942cf69e562f5ed253ac67d5c92a693745f0bed3c81f49fc0cbebe4d6b00/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1732e065704b47c9afca7ffa272f845300a4eb959276bf6970dc07265e73b605", size = 2467260 }, + { url = "https://files.pythonhosted.org/packages/32/26/2d9668f30d8a494b0411d4d7d4ea1345ba12deb6a75274d58dd6ea01e951/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcb1ebc3547619c3b58a39e2448af089ea2ef44b37988caf432447374941574e", size = 2288633 }, + { url = "https://files.pythonhosted.org/packages/98/99/0dd05071654aa44fe5d5e350729961e7bb535372935a45ac89a8924316e6/kiwisolver-1.4.8-cp310-cp310-win_amd64.whl", hash = "sha256:89c107041f7b27844179ea9c85d6da275aa55ecf28413e87624d033cf1f6b751", size = 71885 }, + { url = "https://files.pythonhosted.org/packages/6c/fc/822e532262a97442989335394d441cd1d0448c2e46d26d3e04efca84df22/kiwisolver-1.4.8-cp310-cp310-win_arm64.whl", hash = "sha256:b5773efa2be9eb9fcf5415ea3ab70fc785d598729fd6057bea38d539ead28271", size = 65175 }, + { url = "https://files.pythonhosted.org/packages/da/ed/c913ee28936c371418cb167b128066ffb20bbf37771eecc2c97edf8a6e4c/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a4d3601908c560bdf880f07d94f31d734afd1bb71e96585cace0e38ef44c6d84", size = 124635 }, + { url = "https://files.pythonhosted.org/packages/4c/45/4a7f896f7467aaf5f56ef093d1f329346f3b594e77c6a3c327b2d415f521/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:856b269c4d28a5c0d5e6c1955ec36ebfd1651ac00e1ce0afa3e28da95293b561", size = 66717 }, + { url = "https://files.pythonhosted.org/packages/5f/b4/c12b3ac0852a3a68f94598d4c8d569f55361beef6159dce4e7b624160da2/kiwisolver-1.4.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c2b9a96e0f326205af81a15718a9073328df1173a2619a68553decb7097fd5d7", size = 65413 }, + { url = "https://files.pythonhosted.org/packages/a9/98/1df4089b1ed23d83d410adfdc5947245c753bddfbe06541c4aae330e9e70/kiwisolver-1.4.8-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5020c83e8553f770cb3b5fc13faac40f17e0b205bd237aebd21d53d733adb03", size = 1343994 }, + { url = "https://files.pythonhosted.org/packages/8d/bf/b4b169b050c8421a7c53ea1ea74e4ef9c335ee9013216c558a047f162d20/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dace81d28c787956bfbfbbfd72fdcef014f37d9b48830829e488fdb32b49d954", size = 1434804 }, + { url = "https://files.pythonhosted.org/packages/66/5a/e13bd341fbcf73325ea60fdc8af752addf75c5079867af2e04cc41f34434/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11e1022b524bd48ae56c9b4f9296bce77e15a2e42a502cceba602f804b32bb79", size = 1450690 }, + { url = "https://files.pythonhosted.org/packages/9b/4f/5955dcb376ba4a830384cc6fab7d7547bd6759fe75a09564910e9e3bb8ea/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b9b4d2892fefc886f30301cdd80debd8bb01ecdf165a449eb6e78f79f0fabd6", size = 1376839 }, + { url = "https://files.pythonhosted.org/packages/3a/97/5edbed69a9d0caa2e4aa616ae7df8127e10f6586940aa683a496c2c280b9/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a96c0e790ee875d65e340ab383700e2b4891677b7fcd30a699146f9384a2bb0", size = 1435109 }, + { url = "https://files.pythonhosted.org/packages/13/fc/e756382cb64e556af6c1809a1bbb22c141bbc2445049f2da06b420fe52bf/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:23454ff084b07ac54ca8be535f4174170c1094a4cff78fbae4f73a4bcc0d4dab", size = 2245269 }, + { url = "https://files.pythonhosted.org/packages/76/15/e59e45829d7f41c776d138245cabae6515cb4eb44b418f6d4109c478b481/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:87b287251ad6488e95b4f0b4a79a6d04d3ea35fde6340eb38fbd1ca9cd35bbbc", size = 2393468 }, + { url = "https://files.pythonhosted.org/packages/e9/39/483558c2a913ab8384d6e4b66a932406f87c95a6080112433da5ed668559/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b21dbe165081142b1232a240fc6383fd32cdd877ca6cc89eab93e5f5883e1c25", size = 2355394 }, + { url = "https://files.pythonhosted.org/packages/01/aa/efad1fbca6570a161d29224f14b082960c7e08268a133fe5dc0f6906820e/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:768cade2c2df13db52475bd28d3a3fac8c9eff04b0e9e2fda0f3760f20b3f7fc", size = 2490901 }, + { url = "https://files.pythonhosted.org/packages/c9/4f/15988966ba46bcd5ab9d0c8296914436720dd67fca689ae1a75b4ec1c72f/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d47cfb2650f0e103d4bf68b0b5804c68da97272c84bb12850d877a95c056bd67", size = 2312306 }, + { url = "https://files.pythonhosted.org/packages/2d/27/bdf1c769c83f74d98cbc34483a972f221440703054894a37d174fba8aa68/kiwisolver-1.4.8-cp311-cp311-win_amd64.whl", hash = "sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34", size = 71966 }, + { url = "https://files.pythonhosted.org/packages/4a/c9/9642ea855604aeb2968a8e145fc662edf61db7632ad2e4fb92424be6b6c0/kiwisolver-1.4.8-cp311-cp311-win_arm64.whl", hash = "sha256:16523b40aab60426ffdebe33ac374457cf62863e330a90a0383639ce14bf44b2", size = 65311 }, + { url = "https://files.pythonhosted.org/packages/1f/f9/ae81c47a43e33b93b0a9819cac6723257f5da2a5a60daf46aa5c7226ea85/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e7a019419b7b510f0f7c9dceff8c5eae2392037eae483a7f9162625233802b0a", size = 60403 }, + { url = "https://files.pythonhosted.org/packages/58/ca/f92b5cb6f4ce0c1ebfcfe3e2e42b96917e16f7090e45b21102941924f18f/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:286b18e86682fd2217a48fc6be6b0f20c1d0ed10958d8dc53453ad58d7be0bf8", size = 58657 }, + { url = "https://files.pythonhosted.org/packages/80/28/ae0240f732f0484d3a4dc885d055653c47144bdf59b670aae0ec3c65a7c8/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4191ee8dfd0be1c3666ccbac178c5a05d5f8d689bbe3fc92f3c4abec817f8fe0", size = 84948 }, + { url = "https://files.pythonhosted.org/packages/5d/eb/78d50346c51db22c7203c1611f9b513075f35c4e0e4877c5dde378d66043/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd2785b9391f2873ad46088ed7599a6a71e762e1ea33e87514b1a441ed1da1c", size = 81186 }, + { url = "https://files.pythonhosted.org/packages/43/f8/7259f18c77adca88d5f64f9a522792e178b2691f3748817a8750c2d216ef/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c07b29089b7ba090b6f1a669f1411f27221c3662b3a1b7010e67b59bb5a6f10b", size = 80279 }, + { url = "https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 }, +] + +[[package]] +name = "lark" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036 }, +] + +[[package]] +name = "lsprotocol" +version = "2023.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "cattrs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/f6/6e80484ec078d0b50699ceb1833597b792a6c695f90c645fbaf54b947e6f/lsprotocol-2023.0.1.tar.gz", hash = "sha256:cc5c15130d2403c18b734304339e51242d3018a05c4f7d0f198ad6e0cd21861d", size = 69434 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/37/2351e48cb3309673492d3a8c59d407b75fb6630e560eb27ecd4da03adc9a/lsprotocol-2023.0.1-py3-none-any.whl", hash = "sha256:c75223c9e4af2f24272b14c6375787438279369236cd568f596d4951052a60f2", size = 70826 }, +] + +[[package]] +name = "mako" +version = "1.3.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/d9/8518279534ed7dace1795d5a47e49d5299dd0994eed1053996402a8902f9/mako-1.3.8.tar.gz", hash = "sha256:577b97e414580d3e088d47c2dbbe9594aa7a5146ed2875d4dfa9075af2dd3cc8", size = 392069 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/bf/7a6a36ce2e4cafdfb202752be68850e22607fccd692847c45c1ae3c17ba6/Mako-1.3.8-py3-none-any.whl", hash = "sha256:42f48953c7eb91332040ff567eb7eea69b22e7a4affbc5ba8e845e8f730f6627", size = 78569 }, +] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357 }, + { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", size = 12393 }, + { url = "https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732 }, + { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866 }, + { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964 }, + { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977 }, + { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366 }, + { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091 }, + { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065 }, + { url = "https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514 }, + { url = "https://files.pythonhosted.org/packages/6b/28/bbf83e3f76936960b850435576dd5e67034e200469571be53f69174a2dfd/MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d", size = 14353 }, + { url = "https://files.pythonhosted.org/packages/6c/30/316d194b093cde57d448a4c3209f22e3046c5bb2fb0820b118292b334be7/MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93", size = 12392 }, + { url = "https://files.pythonhosted.org/packages/f2/96/9cdafba8445d3a53cae530aaf83c38ec64c4d5427d975c974084af5bc5d2/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832", size = 23984 }, + { url = "https://files.pythonhosted.org/packages/f1/a4/aefb044a2cd8d7334c8a47d3fb2c9f328ac48cb349468cc31c20b539305f/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84", size = 23120 }, + { url = "https://files.pythonhosted.org/packages/8d/21/5e4851379f88f3fad1de30361db501300d4f07bcad047d3cb0449fc51f8c/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca", size = 23032 }, + { url = "https://files.pythonhosted.org/packages/00/7b/e92c64e079b2d0d7ddf69899c98842f3f9a60a1ae72657c89ce2655c999d/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798", size = 24057 }, + { url = "https://files.pythonhosted.org/packages/f9/ac/46f960ca323037caa0a10662ef97d0a4728e890334fc156b9f9e52bcc4ca/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e", size = 23359 }, + { url = "https://files.pythonhosted.org/packages/69/84/83439e16197337b8b14b6a5b9c2105fff81d42c2a7c5b58ac7b62ee2c3b1/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4", size = 23306 }, + { url = "https://files.pythonhosted.org/packages/9a/34/a15aa69f01e2181ed8d2b685c0d2f6655d5cca2c4db0ddea775e631918cd/MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d", size = 15094 }, + { url = "https://files.pythonhosted.org/packages/da/b8/3a3bd761922d416f3dc5d00bfbed11f66b1ab89a0c2b6e887240a30b0f6b/MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b", size = 15521 }, +] + +[[package]] +name = "matplotlib" +version = "3.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "contourpy" }, + { name = "cycler" }, + { name = "fonttools" }, + { name = "kiwisolver" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "pyparsing" }, + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/dd/fa2e1a45fce2d09f4aea3cee169760e672c8262325aa5796c49d543dc7e6/matplotlib-3.10.0.tar.gz", hash = "sha256:b886d02a581b96704c9d1ffe55709e49b4d2d52709ccebc4be42db856e511278", size = 36686418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/ec/3cdff7b5239adaaacefcc4f77c316dfbbdf853c4ed2beec467e0fec31b9f/matplotlib-3.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6", size = 8160551 }, + { url = "https://files.pythonhosted.org/packages/41/f2/b518f2c7f29895c9b167bf79f8529c63383ae94eaf49a247a4528e9a148d/matplotlib-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2a43cbefe22d653ab34bb55d42384ed30f611bcbdea1f8d7f431011a2e1c62e", size = 8034853 }, + { url = "https://files.pythonhosted.org/packages/ed/8d/45754b4affdb8f0d1a44e4e2bcd932cdf35b256b60d5eda9f455bb293ed0/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:607b16c8a73943df110f99ee2e940b8a1cbf9714b65307c040d422558397dac5", size = 8446724 }, + { url = "https://files.pythonhosted.org/packages/09/5a/a113495110ae3e3395c72d82d7bc4802902e46dc797f6b041e572f195c56/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6", size = 8583905 }, + { url = "https://files.pythonhosted.org/packages/12/b1/8b1655b4c9ed4600c817c419f7eaaf70082630efd7556a5b2e77a8a3cdaf/matplotlib-3.10.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5e6c6461e1fc63df30bf6f80f0b93f5b6784299f721bc28530477acd51bfc3d1", size = 9395223 }, + { url = "https://files.pythonhosted.org/packages/5a/85/b9a54d64585a6b8737a78a61897450403c30f39e0bd3214270bb0b96f002/matplotlib-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:994c07b9d9fe8d25951e3202a68c17900679274dadfc1248738dcfa1bd40d7f3", size = 8025355 }, + { url = "https://files.pythonhosted.org/packages/0c/f1/e37f6c84d252867d7ddc418fff70fc661cfd363179263b08e52e8b748e30/matplotlib-3.10.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fd44fc75522f58612ec4a33958a7e5552562b7705b42ef1b4f8c0818e304a363", size = 8171677 }, + { url = "https://files.pythonhosted.org/packages/c7/8b/92e9da1f28310a1f6572b5c55097b0c0ceb5e27486d85fb73b54f5a9b939/matplotlib-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c58a9622d5dbeb668f407f35f4e6bfac34bb9ecdcc81680c04d0258169747997", size = 8044945 }, + { url = "https://files.pythonhosted.org/packages/c5/cb/49e83f0fd066937a5bd3bc5c5d63093703f3637b2824df8d856e0558beef/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:845d96568ec873be63f25fa80e9e7fae4be854a66a7e2f0c8ccc99e94a8bd4ef", size = 8458269 }, + { url = "https://files.pythonhosted.org/packages/b2/7d/2d873209536b9ee17340754118a2a17988bc18981b5b56e6715ee07373ac/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5439f4c5a3e2e8eab18e2f8c3ef929772fd5641876db71f08127eed95ab64683", size = 8599369 }, + { url = "https://files.pythonhosted.org/packages/b8/03/57d6cbbe85c61fe4cbb7c94b54dce443d68c21961830833a1f34d056e5ea/matplotlib-3.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4673ff67a36152c48ddeaf1135e74ce0d4bce1bbf836ae40ed39c29edf7e2765", size = 9405992 }, + { url = "https://files.pythonhosted.org/packages/14/cf/e382598f98be11bf51dd0bc60eca44a517f6793e3dc8b9d53634a144620c/matplotlib-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e8632baebb058555ac0cde75db885c61f1212e47723d63921879806b40bec6a", size = 8034580 }, + { url = "https://files.pythonhosted.org/packages/32/5f/29def7ce4e815ab939b56280976ee35afffb3bbdb43f332caee74cb8c951/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:81713dd0d103b379de4516b861d964b1d789a144103277769238c732229d7f03", size = 8155500 }, + { url = "https://files.pythonhosted.org/packages/de/6d/d570383c9f7ca799d0a54161446f9ce7b17d6c50f2994b653514bcaa108f/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:359f87baedb1f836ce307f0e850d12bb5f1936f70d035561f90d41d305fdacea", size = 8032398 }, + { url = "https://files.pythonhosted.org/packages/c9/b4/680aa700d99b48e8c4393fa08e9ab8c49c0555ee6f4c9c0a5e8ea8dfde5d/matplotlib-3.10.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae80dc3a4add4665cf2faa90138384a7ffe2a4e37c58d83e115b54287c4f06ef", size = 8587361 }, +] + +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, +] + +[[package]] +name = "mdit-py-plugins" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/03/a2ecab526543b152300717cf232bb4bb8605b6edb946c845016fa9c9c9fd/mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5", size = 43542 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/f7/7782a043553ee469c1ff49cfa1cdace2d6bf99a1f333cf38676b3ddf30da/mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636", size = 55316 }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + +[[package]] +name = "ml-dtypes" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/49/6e67c334872d2c114df3020e579f3718c333198f8312290e09ec0216703a/ml_dtypes-0.5.1.tar.gz", hash = "sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9", size = 698772 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/88/11ebdbc75445eeb5b6869b708a0d787d1ed812ff86c2170bbfb95febdce1/ml_dtypes-0.5.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190", size = 671450 }, + { url = "https://files.pythonhosted.org/packages/a4/a4/9321cae435d6140f9b0e7af8334456a854b60e3a9c6101280a16e3594965/ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed", size = 4621075 }, + { url = "https://files.pythonhosted.org/packages/16/d8/4502e12c6a10d42e13a552e8d97f20198e3cf82a0d1411ad50be56a5077c/ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe", size = 4738414 }, + { url = "https://files.pythonhosted.org/packages/6b/7e/bc54ae885e4d702e60a4bf50aa9066ff35e9c66b5213d11091f6bffb3036/ml_dtypes-0.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4", size = 209718 }, + { url = "https://files.pythonhosted.org/packages/c9/fd/691335926126bb9beeb030b61a28f462773dcf16b8e8a2253b599013a303/ml_dtypes-0.5.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327", size = 671448 }, + { url = "https://files.pythonhosted.org/packages/ff/a6/63832d91f2feb250d865d069ba1a5d0c686b1f308d1c74ce9764472c5e22/ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f", size = 4625792 }, + { url = "https://files.pythonhosted.org/packages/cc/2a/5421fd3dbe6eef9b844cc9d05f568b9fb568503a2e51cb1eb4443d9fc56b/ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab", size = 4743893 }, + { url = "https://files.pythonhosted.org/packages/60/30/d3f0fc9499a22801219679a7f3f8d59f1429943c6261f445fb4bfce20718/ml_dtypes-0.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478", size = 209712 }, +] + +[[package]] +name = "more-itertools" +version = "10.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/3b/7fa1fe835e2e93fd6d7b52b2f95ae810cf5ba133e1845f726f5a992d62c2/more-itertools-10.6.0.tar.gz", hash = "sha256:2cd7fad1009c31cc9fb6a035108509e6547547a7a738374f10bd49a09eb3ee3b", size = 125009 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/62/0fe302c6d1be1c777cab0616e6302478251dfbf9055ad426f5d0def75c89/more_itertools-10.6.0-py3-none-any.whl", hash = "sha256:6eb054cb4b6db1473f6e15fcc676a08e4732548acd47c708f0e179c2c7c01e89", size = 63038 }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, +] + +[[package]] +name = "msgpack" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/d0/7555686ae7ff5731205df1012ede15dd9d927f6227ea151e901c7406af4f/msgpack-1.1.0.tar.gz", hash = "sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e", size = 167260 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/f9/a892a6038c861fa849b11a2bb0502c07bc698ab6ea53359e5771397d883b/msgpack-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ad442d527a7e358a469faf43fda45aaf4ac3249c8310a82f0ccff9164e5dccd", size = 150428 }, + { url = "https://files.pythonhosted.org/packages/df/7a/d174cc6a3b6bb85556e6a046d3193294a92f9a8e583cdbd46dc8a1d7e7f4/msgpack-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74bed8f63f8f14d75eec75cf3d04ad581da6b914001b474a5d3cd3372c8cc27d", size = 84131 }, + { url = "https://files.pythonhosted.org/packages/08/52/bf4fbf72f897a23a56b822997a72c16de07d8d56d7bf273242f884055682/msgpack-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:914571a2a5b4e7606997e169f64ce53a8b1e06f2cf2c3a7273aa106236d43dd5", size = 81215 }, + { url = "https://files.pythonhosted.org/packages/02/95/dc0044b439b518236aaf012da4677c1b8183ce388411ad1b1e63c32d8979/msgpack-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c921af52214dcbb75e6bdf6a661b23c3e6417f00c603dd2070bccb5c3ef499f5", size = 371229 }, + { url = "https://files.pythonhosted.org/packages/ff/75/09081792db60470bef19d9c2be89f024d366b1e1973c197bb59e6aabc647/msgpack-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8ce0b22b890be5d252de90d0e0d119f363012027cf256185fc3d474c44b1b9e", size = 378034 }, + { url = "https://files.pythonhosted.org/packages/32/d3/c152e0c55fead87dd948d4b29879b0f14feeeec92ef1fd2ec21b107c3f49/msgpack-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73322a6cc57fcee3c0c57c4463d828e9428275fb85a27aa2aa1a92fdc42afd7b", size = 363070 }, + { url = "https://files.pythonhosted.org/packages/d9/2c/82e73506dd55f9e43ac8aa007c9dd088c6f0de2aa19e8f7330e6a65879fc/msgpack-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e1f3c3d21f7cf67bcf2da8e494d30a75e4cf60041d98b3f79875afb5b96f3a3f", size = 359863 }, + { url = "https://files.pythonhosted.org/packages/cb/a0/3d093b248837094220e1edc9ec4337de3443b1cfeeb6e0896af8ccc4cc7a/msgpack-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:64fc9068d701233effd61b19efb1485587560b66fe57b3e50d29c5d78e7fef68", size = 368166 }, + { url = "https://files.pythonhosted.org/packages/e4/13/7646f14f06838b406cf5a6ddbb7e8dc78b4996d891ab3b93c33d1ccc8678/msgpack-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:42f754515e0f683f9c79210a5d1cad631ec3d06cea5172214d2176a42e67e19b", size = 370105 }, + { url = "https://files.pythonhosted.org/packages/67/fa/dbbd2443e4578e165192dabbc6a22c0812cda2649261b1264ff515f19f15/msgpack-1.1.0-cp310-cp310-win32.whl", hash = "sha256:3df7e6b05571b3814361e8464f9304c42d2196808e0119f55d0d3e62cd5ea044", size = 68513 }, + { url = "https://files.pythonhosted.org/packages/24/ce/c2c8fbf0ded750cb63cbcbb61bc1f2dfd69e16dca30a8af8ba80ec182dcd/msgpack-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:685ec345eefc757a7c8af44a3032734a739f8c45d1b0ac45efc5d8977aa4720f", size = 74687 }, + { url = "https://files.pythonhosted.org/packages/b7/5e/a4c7154ba65d93be91f2f1e55f90e76c5f91ccadc7efc4341e6f04c8647f/msgpack-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d364a55082fb2a7416f6c63ae383fbd903adb5a6cf78c5b96cc6316dc1cedc7", size = 150803 }, + { url = "https://files.pythonhosted.org/packages/60/c2/687684164698f1d51c41778c838d854965dd284a4b9d3a44beba9265c931/msgpack-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:79ec007767b9b56860e0372085f8504db5d06bd6a327a335449508bbee9648fa", size = 84343 }, + { url = "https://files.pythonhosted.org/packages/42/ae/d3adea9bb4a1342763556078b5765e666f8fdf242e00f3f6657380920972/msgpack-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ad622bf7756d5a497d5b6836e7fc3752e2dd6f4c648e24b1803f6048596f701", size = 81408 }, + { url = "https://files.pythonhosted.org/packages/dc/17/6313325a6ff40ce9c3207293aee3ba50104aed6c2c1559d20d09e5c1ff54/msgpack-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e59bca908d9ca0de3dc8684f21ebf9a690fe47b6be93236eb40b99af28b6ea6", size = 396096 }, + { url = "https://files.pythonhosted.org/packages/a8/a1/ad7b84b91ab5a324e707f4c9761633e357820b011a01e34ce658c1dda7cc/msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1da8f11a3dd397f0a32c76165cf0c4eb95b31013a94f6ecc0b280c05c91b59", size = 403671 }, + { url = "https://files.pythonhosted.org/packages/bb/0b/fd5b7c0b308bbf1831df0ca04ec76fe2f5bf6319833646b0a4bd5e9dc76d/msgpack-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:452aff037287acb1d70a804ffd022b21fa2bb7c46bee884dbc864cc9024128a0", size = 387414 }, + { url = "https://files.pythonhosted.org/packages/f0/03/ff8233b7c6e9929a1f5da3c7860eccd847e2523ca2de0d8ef4878d354cfa/msgpack-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8da4bf6d54ceed70e8861f833f83ce0814a2b72102e890cbdfe4b34764cdd66e", size = 383759 }, + { url = "https://files.pythonhosted.org/packages/1f/1b/eb82e1fed5a16dddd9bc75f0854b6e2fe86c0259c4353666d7fab37d39f4/msgpack-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:41c991beebf175faf352fb940bf2af9ad1fb77fd25f38d9142053914947cdbf6", size = 394405 }, + { url = "https://files.pythonhosted.org/packages/90/2e/962c6004e373d54ecf33d695fb1402f99b51832631e37c49273cc564ffc5/msgpack-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a52a1f3a5af7ba1c9ace055b659189f6c669cf3657095b50f9602af3a3ba0fe5", size = 396041 }, + { url = "https://files.pythonhosted.org/packages/f8/20/6e03342f629474414860c48aeffcc2f7f50ddaf351d95f20c3f1c67399a8/msgpack-1.1.0-cp311-cp311-win32.whl", hash = "sha256:58638690ebd0a06427c5fe1a227bb6b8b9fdc2bd07701bec13c2335c82131a88", size = 68538 }, + { url = "https://files.pythonhosted.org/packages/aa/c4/5a582fc9a87991a3e6f6800e9bb2f3c82972912235eb9539954f3e9997c7/msgpack-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd2906780f25c8ed5d7b323379f6138524ba793428db5d0e9d226d3fa6aa1788", size = 74871 }, +] + +[[package]] +name = "mypy" +version = "1.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/eb/2c92d8ea1e684440f54fa49ac5d9a5f19967b7b472a281f419e69a8d228e/mypy-1.14.1.tar.gz", hash = "sha256:7ec88144fe9b510e8475ec2f5f251992690fcf89ccb4500b214b4226abcd32d6", size = 3216051 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/7a/87ae2adb31d68402da6da1e5f30c07ea6063e9f09b5e7cfc9dfa44075e74/mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb", size = 11211002 }, + { url = "https://files.pythonhosted.org/packages/e1/23/eada4c38608b444618a132be0d199b280049ded278b24cbb9d3fc59658e4/mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0", size = 10358400 }, + { url = "https://files.pythonhosted.org/packages/43/c9/d6785c6f66241c62fd2992b05057f404237deaad1566545e9f144ced07f5/mypy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:90716d8b2d1f4cd503309788e51366f07c56635a3309b0f6a32547eaaa36a64d", size = 12095172 }, + { url = "https://files.pythonhosted.org/packages/c3/62/daa7e787770c83c52ce2aaf1a111eae5893de9e004743f51bfcad9e487ec/mypy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ae753f5c9fef278bcf12e1a564351764f2a6da579d4a81347e1d5a15819997b", size = 12828732 }, + { url = "https://files.pythonhosted.org/packages/1b/a2/5fb18318a3637f29f16f4e41340b795da14f4751ef4f51c99ff39ab62e52/mypy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e0fe0f5feaafcb04505bcf439e991c6d8f1bf8b15f12b05feeed96e9e7bf1427", size = 13012197 }, + { url = "https://files.pythonhosted.org/packages/28/99/e153ce39105d164b5f02c06c35c7ba958aaff50a2babba7d080988b03fe7/mypy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:7d54bd85b925e501c555a3227f3ec0cfc54ee8b6930bd6141ec872d1c572f81f", size = 9780836 }, + { url = "https://files.pythonhosted.org/packages/da/11/a9422850fd506edbcdc7f6090682ecceaf1f87b9dd847f9df79942da8506/mypy-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f995e511de847791c3b11ed90084a7a0aafdc074ab88c5a9711622fe4751138c", size = 11120432 }, + { url = "https://files.pythonhosted.org/packages/b6/9e/47e450fd39078d9c02d620545b2cb37993a8a8bdf7db3652ace2f80521ca/mypy-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d64169ec3b8461311f8ce2fd2eb5d33e2d0f2c7b49116259c51d0d96edee48d1", size = 10279515 }, + { url = "https://files.pythonhosted.org/packages/01/b5/6c8d33bd0f851a7692a8bfe4ee75eb82b6983a3cf39e5e32a5d2a723f0c1/mypy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ba24549de7b89b6381b91fbc068d798192b1b5201987070319889e93038967a8", size = 12025791 }, + { url = "https://files.pythonhosted.org/packages/f0/4c/e10e2c46ea37cab5c471d0ddaaa9a434dc1d28650078ac1b56c2d7b9b2e4/mypy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:183cf0a45457d28ff9d758730cd0210419ac27d4d3f285beda038c9083363b1f", size = 12749203 }, + { url = "https://files.pythonhosted.org/packages/88/55/beacb0c69beab2153a0f57671ec07861d27d735a0faff135a494cd4f5020/mypy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f2a0ecc86378f45347f586e4163d1769dd81c5a223d577fe351f26b179e148b1", size = 12885900 }, + { url = "https://files.pythonhosted.org/packages/a2/75/8c93ff7f315c4d086a2dfcde02f713004357d70a163eddb6c56a6a5eff40/mypy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:ad3301ebebec9e8ee7135d8e3109ca76c23752bac1e717bc84cd3836b4bf3eae", size = 9777869 }, + { url = "https://files.pythonhosted.org/packages/a0/b5/32dd67b69a16d088e533962e5044e51004176a9952419de0370cdaead0f8/mypy-1.14.1-py3-none-any.whl", hash = "sha256:b66a60cc4073aeb8ae00057f9c1f64d49e90f918fbcef9a977eb121da8b8f1d1", size = 2752905 }, +] + +[package.optional-dependencies] +faster-cache = [ + { name = "orjson" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + +[[package]] +name = "myst-parser" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "jinja2" }, + { name = "markdown-it-py" }, + { name = "mdit-py-plugins" }, + { name = "pyyaml" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/55/6d1741a1780e5e65038b74bce6689da15f620261c490c3511eb4c12bac4b/myst_parser-4.0.0.tar.gz", hash = "sha256:851c9dfb44e36e56d15d05e72f02b80da21a9e0d07cba96baf5e2d476bb91531", size = 93858 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/b4/b036f8fdb667587bb37df29dc6644681dd78b7a2a6321a34684b79412b28/myst_parser-4.0.0-py3-none-any.whl", hash = "sha256:b9317997552424448c6096c2558872fdb6f81d3ecb3a40ce84a7518798f3f28d", size = 84563 }, +] + +[[package]] +name = "nanobind" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/01/a28722f6626e5c8a606dee71cb40c0b2ab9f7715b96bd34a9553c79dbf42/nanobind-2.4.0.tar.gz", hash = "sha256:a0392dee5f58881085b2ac8bfe8e53f74285aa4868b1472bfaf76cfb414e1c96", size = 953467 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/07/abff41fcade3613349eac71dacb166352babef515efd960a751e3175c262/nanobind-2.4.0-py3-none-any.whl", hash = "sha256:8cf27b04fbadeb9deb4a73f02bd838bf9f7e3e5a8ce44c50c93142b5728da58a", size = 232882 }, +] + +[[package]] +name = "natsort" +version = "8.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e2/a9/a0c57aee75f77794adaf35322f8b6404cbd0f89ad45c87197a937764b7d0/natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581", size = 76575 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/82/7a9d0550484a62c6da82858ee9419f3dd1ccc9aa1c26a1e43da3ecd20b0d/natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c", size = 38268 }, +] + +[[package]] +name = "nbclient" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "nbformat" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/66/7ffd18d58eae90d5721f9f39212327695b749e23ad44b3881744eaf4d9e8/nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193", size = 62424 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/6d/e7fa07f03a4a7b221d94b4d586edb754a9b0dc3c9e2c93353e9fa4e0d117/nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d", size = 25434 }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastjsonschema" }, + { name = "jsonschema" }, + { name = "jupyter-core" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454 }, +] + +[[package]] +name = "nbmake" +version = "1.5.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ipykernel" }, + { name = "nbclient" }, + { name = "nbformat" }, + { name = "pygments" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/9a/aae201cee5639e1d562b3843af8fd9f8d018bb323e776a2b973bdd5fc64b/nbmake-1.5.5.tar.gz", hash = "sha256:239dc868ea13a7c049746e2aba2c229bd0f6cdbc6bfa1d22f4c88638aa4c5f5c", size = 85929 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/be/b257e12f9710819fde40adc972578bee6b72c5992da1bc8369bef2597756/nbmake-1.5.5-py3-none-any.whl", hash = "sha256:c6fbe6e48b60cacac14af40b38bf338a3b88f47f085c54ac5b8639ff0babaf4b", size = 12818 }, +] + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195 }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 }, +] + +[[package]] +name = "ninja" +version = "1.11.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/8f/21a2701f95b7d0d5137736561b3427ece0c4a1e085d4a223b92d16ab7d8b/ninja-1.11.1.3.tar.gz", hash = "sha256:edfa0d2e9d7ead1635b03e40a32ad56cc8f56798b6e2e9848d8300b174897076", size = 129532 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/ba/0069cd4a83d68f7b0308be70e219b15d675e50c8ea28763a3f0373c45bfc/ninja-1.11.1.3-py3-none-macosx_10_9_universal2.whl", hash = "sha256:2b4879ea3f1169f3d855182c57dcc84d1b5048628c8b7be0d702b81882a37237", size = 279132 }, + { url = "https://files.pythonhosted.org/packages/72/6b/3805be87df8417a0c7b21078c8045f2a1e59b34f371bfe4cb4fb0d6df7f2/ninja-1.11.1.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc3ebc8b2e47716149f3541742b5cd8e0b08f51013b825c05baca3e34854370d", size = 472101 }, + { url = "https://files.pythonhosted.org/packages/6b/35/a8e38d54768e67324e365e2a41162be298f51ec93e6bd4b18d237d7250d8/ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a27e78ca71316c8654965ee94b286a98c83877bfebe2607db96897bbfe458af0", size = 422884 }, + { url = "https://files.pythonhosted.org/packages/2f/99/7996457319e139c02697fb2aa28e42fe32bb0752cef492edc69d56a3552e/ninja-1.11.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2883ea46b3c5079074f56820f9989c6261fcc6fd873d914ee49010ecf283c3b2", size = 157046 }, + { url = "https://files.pythonhosted.org/packages/6d/8b/93f38e5cddf76ccfdab70946515b554f25d2b4c95ef9b2f9cfbc43fa7cc1/ninja-1.11.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c4bdb9fd2d0c06501ae15abfd23407660e95659e384acd36e013b6dd7d8a8e4", size = 180014 }, + { url = "https://files.pythonhosted.org/packages/7d/1d/713884d0fa3c972164f69d552e0701d30e2bf25eba9ef160bfb3dc69926a/ninja-1.11.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:114ed5c61c8474df6a69ab89097a20749b769e2c219a452cb2fadc49b0d581b0", size = 157098 }, + { url = "https://files.pythonhosted.org/packages/c7/22/ecb0f70e77c9e22ee250aa717a608a142756833a34d43943d7d658ee0e56/ninja-1.11.1.3-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fa2247fce98f683bc712562d82b22b8a0a5c000738a13147ca2d1b68c122298", size = 130089 }, + { url = "https://files.pythonhosted.org/packages/ec/a6/3ee846c20ab6ad95b90c5c8703c76cb1f39cc8ce2d1ae468956e3b1b2581/ninja-1.11.1.3-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:a38c6c6c8032bed68b70c3b065d944c35e9f903342875d3a3218c1607987077c", size = 372508 }, + { url = "https://files.pythonhosted.org/packages/95/0d/aa44abe4141f29148ce671ac8c92045878906b18691c6f87a29711c2ff1c/ninja-1.11.1.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:56ada5d33b8741d298836644042faddebc83ee669782d661e21563034beb5aba", size = 419369 }, + { url = "https://files.pythonhosted.org/packages/f7/ec/48bf5105568ac9bd2016b701777bdd5000cc09a14ac837fef9f15e8d634e/ninja-1.11.1.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:53409151da081f3c198bb0bfc220a7f4e821e022c5b7d29719adda892ddb31bb", size = 420304 }, + { url = "https://files.pythonhosted.org/packages/18/e5/69df63976cf971a03379899f8520a036c9dbab26330b37197512aed5b3df/ninja-1.11.1.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:1ad2112c2b0159ed7c4ae3731595191b1546ba62316fc40808edecd0306fefa3", size = 416056 }, + { url = "https://files.pythonhosted.org/packages/6f/4f/bdb401af7ed0e24a3fef058e13a149f2de1ce4b176699076993615d55610/ninja-1.11.1.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:28aea3c1c280cba95b8608d50797169f3a34280e3e9a6379b6e340f0c9eaeeb0", size = 379725 }, + { url = "https://files.pythonhosted.org/packages/bd/68/05e7863bf13128c61652eeb3ec7096c3d3a602f32f31752dbfb034e3fa07/ninja-1.11.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b6966f83064a88a51693073eea3decd47e08c3965241e09578ef7aa3a7738329", size = 434881 }, + { url = "https://files.pythonhosted.org/packages/bd/ad/edc0d1efe77f29f45bbca2e1dab07ef597f61a88de6e4bccffc0aec2256c/ninja-1.11.1.3-py3-none-win32.whl", hash = "sha256:a4a3b71490557e18c010cbb26bd1ea9a0c32ee67e8f105e9731515b6e0af792e", size = 255988 }, + { url = "https://files.pythonhosted.org/packages/03/93/09a9f7672b4f97438aca6217ac54212a63273f1cd3b46b731d0bb22c53e7/ninja-1.11.1.3-py3-none-win_amd64.whl", hash = "sha256:04d48d14ea7ba11951c156599ab526bdda575450797ff57c6fdf99b2554d09c7", size = 296502 }, + { url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 }, +] + +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + +[[package]] +name = "nox" +version = "2024.10.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argcomplete" }, + { name = "colorlog" }, + { name = "packaging" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/93/4df547afcd56e0b2bbaa99bc2637deb218a01802ed62d80f763189be802c/nox-2024.10.9.tar.gz", hash = "sha256:7aa9dc8d1c27e9f45ab046ffd1c3b2c4f7c91755304769df231308849ebded95", size = 4003197 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/00/981f0dcaddf111b6caf6e03d7f7f01b07fd4af117316a7eb1c22039d9e37/nox-2024.10.9-py3-none-any.whl", hash = "sha256:1d36f309a0a2a853e9bccb76bbef6bb118ba92fa92674d15604ca99adeb29eab", size = 61210 }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468 }, + { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411 }, + { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016 }, + { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889 }, + { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746 }, + { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620 }, + { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659 }, + { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905 }, + { url = "https://files.pythonhosted.org/packages/11/57/baae43d14fe163fa0e4c47f307b6b2511ab8d7d30177c491960504252053/numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71", size = 20630554 }, + { url = "https://files.pythonhosted.org/packages/1a/2e/151484f49fd03944c4a3ad9c418ed193cfd02724e138ac8a9505d056c582/numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef", size = 13997127 }, + { url = "https://files.pythonhosted.org/packages/79/ae/7e5b85136806f9dadf4878bf73cf223fe5c2636818ba3ab1c585d0403164/numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e", size = 14222994 }, + { url = "https://files.pythonhosted.org/packages/3a/d0/edc009c27b406c4f9cbc79274d6e46d634d139075492ad055e3d68445925/numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5", size = 18252005 }, + { url = "https://files.pythonhosted.org/packages/09/bf/2b1aaf8f525f2923ff6cfcf134ae5e750e279ac65ebf386c75a0cf6da06a/numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a", size = 13885297 }, + { url = "https://files.pythonhosted.org/packages/df/a0/4e0f14d847cfc2a633a1c8621d00724f3206cfeddeb66d35698c4e2cf3d2/numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a", size = 18093567 }, + { url = "https://files.pythonhosted.org/packages/d2/b7/a734c733286e10a7f1a8ad1ae8c90f2d33bf604a96548e0a4a3a6739b468/numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20", size = 5968812 }, + { url = "https://files.pythonhosted.org/packages/3f/6b/5610004206cf7f8e7ad91c5a85a8c71b2f2f8051a0c0c4d5916b76d6cbb2/numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", size = 15811913 }, +] + +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, +] + +[[package]] +name = "orderly-set" +version = "5.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/9e/8fdcb9ab1b6983cc7c185a4ddafc27518118bd80e9ff2f30aba83636af37/orderly_set-5.2.3.tar.gz", hash = "sha256:571ed97c5a5fca7ddeb6b2d26c19aca896b0ed91f334d9c109edd2f265fb3017", size = 19698 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/bb/a3a4eab8430f14c7d1476f9db261d32654cb3d1794c0266a46f6574e1190/orderly_set-5.2.3-py3-none-any.whl", hash = "sha256:d357cedcf67f4ebff0d4cbd5b0997e98eeb65dd24fdf5c990a501ae9e82c7d34", size = 12024 }, +] + +[[package]] +name = "orjson" +version = "3.10.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/f9/5dea21763eeff8c1590076918a446ea3d6140743e0e36f58f369928ed0f4/orjson-3.10.15.tar.gz", hash = "sha256:05ca7fe452a2e9d8d9d706a2984c95b9c2ebc5db417ce0b7a49b91d50642a23e", size = 5282482 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/09/e5ff18ad009e6f97eb7edc5f67ef98b3ce0c189da9c3eaca1f9587cd4c61/orjson-3.10.15-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:552c883d03ad185f720d0c09583ebde257e41b9521b74ff40e08b7dec4559c04", size = 249532 }, + { url = "https://files.pythonhosted.org/packages/bd/b8/a75883301fe332bd433d9b0ded7d2bb706ccac679602c3516984f8814fb5/orjson-3.10.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616e3e8d438d02e4854f70bfdc03a6bcdb697358dbaa6bcd19cbe24d24ece1f8", size = 125229 }, + { url = "https://files.pythonhosted.org/packages/83/4b/22f053e7a364cc9c685be203b1e40fc5f2b3f164a9b2284547504eec682e/orjson-3.10.15-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c2c79fa308e6edb0ffab0a31fd75a7841bf2a79a20ef08a3c6e3b26814c8ca8", size = 150148 }, + { url = "https://files.pythonhosted.org/packages/63/64/1b54fc75ca328b57dd810541a4035fe48c12a161d466e3cf5b11a8c25649/orjson-3.10.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cb85490aa6bf98abd20607ab5c8324c0acb48d6da7863a51be48505646c814", size = 139748 }, + { url = "https://files.pythonhosted.org/packages/5e/ff/ff0c5da781807bb0a5acd789d9a7fbcb57f7b0c6e1916595da1f5ce69f3c/orjson-3.10.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763dadac05e4e9d2bc14938a45a2d0560549561287d41c465d3c58aec818b164", size = 154559 }, + { url = "https://files.pythonhosted.org/packages/4e/9a/11e2974383384ace8495810d4a2ebef5f55aacfc97b333b65e789c9d362d/orjson-3.10.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a330b9b4734f09a623f74a7490db713695e13b67c959713b78369f26b3dee6bf", size = 130349 }, + { url = "https://files.pythonhosted.org/packages/2d/c4/dd9583aea6aefee1b64d3aed13f51d2aadb014028bc929fe52936ec5091f/orjson-3.10.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a61a4622b7ff861f019974f73d8165be1bd9a0855e1cad18ee167acacabeb061", size = 138514 }, + { url = "https://files.pythonhosted.org/packages/53/3e/dcf1729230654f5c5594fc752de1f43dcf67e055ac0d300c8cdb1309269a/orjson-3.10.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:acd271247691574416b3228db667b84775c497b245fa275c6ab90dc1ffbbd2b3", size = 130940 }, + { url = "https://files.pythonhosted.org/packages/e8/2b/b9759fe704789937705c8a56a03f6c03e50dff7df87d65cba9a20fec5282/orjson-3.10.15-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:e4759b109c37f635aa5c5cc93a1b26927bfde24b254bcc0e1149a9fada253d2d", size = 414713 }, + { url = "https://files.pythonhosted.org/packages/a7/6b/b9dfdbd4b6e20a59238319eb203ae07c3f6abf07eef909169b7a37ae3bba/orjson-3.10.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9e992fd5cfb8b9f00bfad2fd7a05a4299db2bbe92e6440d9dd2fab27655b3182", size = 141028 }, + { url = "https://files.pythonhosted.org/packages/7c/b5/40f5bbea619c7caf75eb4d652a9821875a8ed04acc45fe3d3ef054ca69fb/orjson-3.10.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f95fb363d79366af56c3f26b71df40b9a583b07bbaaf5b317407c4d58497852e", size = 129715 }, + { url = "https://files.pythonhosted.org/packages/38/60/2272514061cbdf4d672edbca6e59c7e01cd1c706e881427d88f3c3e79761/orjson-3.10.15-cp310-cp310-win32.whl", hash = "sha256:f9875f5fea7492da8ec2444839dcc439b0ef298978f311103d0b7dfd775898ab", size = 142473 }, + { url = "https://files.pythonhosted.org/packages/11/5d/be1490ff7eafe7fef890eb4527cf5bcd8cfd6117f3efe42a3249ec847b60/orjson-3.10.15-cp310-cp310-win_amd64.whl", hash = "sha256:17085a6aa91e1cd70ca8533989a18b5433e15d29c574582f76f821737c8d5806", size = 133564 }, + { url = "https://files.pythonhosted.org/packages/7a/a2/21b25ce4a2c71dbb90948ee81bd7a42b4fbfc63162e57faf83157d5540ae/orjson-3.10.15-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c4cc83960ab79a4031f3119cc4b1a1c627a3dc09df125b27c4201dff2af7eaa6", size = 249533 }, + { url = "https://files.pythonhosted.org/packages/b2/85/2076fc12d8225698a51278009726750c9c65c846eda741e77e1761cfef33/orjson-3.10.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddbeef2481d895ab8be5185f2432c334d6dec1f5d1933a9c83014d188e102cef", size = 125230 }, + { url = "https://files.pythonhosted.org/packages/06/df/a85a7955f11274191eccf559e8481b2be74a7c6d43075d0a9506aa80284d/orjson-3.10.15-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e590a0477b23ecd5b0ac865b1b907b01b3c5535f5e8a8f6ab0e503efb896334", size = 150148 }, + { url = "https://files.pythonhosted.org/packages/37/b3/94c55625a29b8767c0eed194cb000b3787e3c23b4cdd13be17bae6ccbb4b/orjson-3.10.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6be38bd103d2fd9bdfa31c2720b23b5d47c6796bcb1d1b598e3924441b4298d", size = 139749 }, + { url = "https://files.pythonhosted.org/packages/53/ba/c608b1e719971e8ddac2379f290404c2e914cf8e976369bae3cad88768b1/orjson-3.10.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff4f6edb1578960ed628a3b998fa54d78d9bb3e2eb2cfc5c2a09732431c678d0", size = 154558 }, + { url = "https://files.pythonhosted.org/packages/b2/c4/c1fb835bb23ad788a39aa9ebb8821d51b1c03588d9a9e4ca7de5b354fdd5/orjson-3.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0482b21d0462eddd67e7fce10b89e0b6ac56570424662b685a0d6fccf581e13", size = 130349 }, + { url = "https://files.pythonhosted.org/packages/78/14/bb2b48b26ab3c570b284eb2157d98c1ef331a8397f6c8bd983b270467f5c/orjson-3.10.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bb5cc3527036ae3d98b65e37b7986a918955f85332c1ee07f9d3f82f3a6899b5", size = 138513 }, + { url = "https://files.pythonhosted.org/packages/4a/97/d5b353a5fe532e92c46467aa37e637f81af8468aa894cd77d2ec8a12f99e/orjson-3.10.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d569c1c462912acdd119ccbf719cf7102ea2c67dd03b99edcb1a3048651ac96b", size = 130942 }, + { url = "https://files.pythonhosted.org/packages/b5/5d/a067bec55293cca48fea8b9928cfa84c623be0cce8141d47690e64a6ca12/orjson-3.10.15-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:1e6d33efab6b71d67f22bf2962895d3dc6f82a6273a965fab762e64fa90dc399", size = 414717 }, + { url = "https://files.pythonhosted.org/packages/6f/9a/1485b8b05c6b4c4db172c438cf5db5dcfd10e72a9bc23c151a1137e763e0/orjson-3.10.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c33be3795e299f565681d69852ac8c1bc5c84863c0b0030b2b3468843be90388", size = 141033 }, + { url = "https://files.pythonhosted.org/packages/f8/d2/fc67523656e43a0c7eaeae9007c8b02e86076b15d591e9be11554d3d3138/orjson-3.10.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eea80037b9fae5339b214f59308ef0589fc06dc870578b7cce6d71eb2096764c", size = 129720 }, + { url = "https://files.pythonhosted.org/packages/79/42/f58c7bd4e5b54da2ce2ef0331a39ccbbaa7699b7f70206fbf06737c9ed7d/orjson-3.10.15-cp311-cp311-win32.whl", hash = "sha256:d5ac11b659fd798228a7adba3e37c010e0152b78b1982897020a8e019a94882e", size = 142473 }, + { url = "https://files.pythonhosted.org/packages/00/f8/bb60a4644287a544ec81df1699d5b965776bc9848d9029d9f9b3402ac8bb/orjson-3.10.15-cp311-cp311-win_amd64.whl", hash = "sha256:cf45e0214c593660339ef63e875f32ddd5aa3b4adc15e662cdb80dc49e194f8e", size = 133570 }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "parso" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, +] + +[[package]] +name = "pillow" +version = "11.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/af/c097e544e7bd278333db77933e535098c259609c4eb3b85381109602fb5b/pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20", size = 46742715 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/1c/2dcea34ac3d7bc96a1fd1bd0a6e06a57c67167fec2cff8d95d88229a8817/pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8", size = 3229983 }, + { url = "https://files.pythonhosted.org/packages/14/ca/6bec3df25e4c88432681de94a3531cc738bd85dea6c7aa6ab6f81ad8bd11/pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192", size = 3101831 }, + { url = "https://files.pythonhosted.org/packages/d4/2c/668e18e5521e46eb9667b09e501d8e07049eb5bfe39d56be0724a43117e6/pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2", size = 4314074 }, + { url = "https://files.pythonhosted.org/packages/02/80/79f99b714f0fc25f6a8499ecfd1f810df12aec170ea1e32a4f75746051ce/pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26", size = 4394933 }, + { url = "https://files.pythonhosted.org/packages/81/aa/8d4ad25dc11fd10a2001d5b8a80fdc0e564ac33b293bdfe04ed387e0fd95/pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07", size = 4353349 }, + { url = "https://files.pythonhosted.org/packages/84/7a/cd0c3eaf4a28cb2a74bdd19129f7726277a7f30c4f8424cd27a62987d864/pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482", size = 4476532 }, + { url = "https://files.pythonhosted.org/packages/8f/8b/a907fdd3ae8f01c7670dfb1499c53c28e217c338b47a813af8d815e7ce97/pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e", size = 4279789 }, + { url = "https://files.pythonhosted.org/packages/6f/9a/9f139d9e8cccd661c3efbf6898967a9a337eb2e9be2b454ba0a09533100d/pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269", size = 4413131 }, + { url = "https://files.pythonhosted.org/packages/a8/68/0d8d461f42a3f37432203c8e6df94da10ac8081b6d35af1c203bf3111088/pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49", size = 2291213 }, + { url = "https://files.pythonhosted.org/packages/14/81/d0dff759a74ba87715509af9f6cb21fa21d93b02b3316ed43bda83664db9/pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a", size = 2625725 }, + { url = "https://files.pythonhosted.org/packages/ce/1f/8d50c096a1d58ef0584ddc37e6f602828515219e9d2428e14ce50f5ecad1/pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65", size = 2375213 }, + { url = "https://files.pythonhosted.org/packages/dd/d6/2000bfd8d5414fb70cbbe52c8332f2283ff30ed66a9cde42716c8ecbe22c/pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457", size = 3229968 }, + { url = "https://files.pythonhosted.org/packages/d9/45/3fe487010dd9ce0a06adf9b8ff4f273cc0a44536e234b0fad3532a42c15b/pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35", size = 3101806 }, + { url = "https://files.pythonhosted.org/packages/e3/72/776b3629c47d9d5f1c160113158a7a7ad177688d3a1159cd3b62ded5a33a/pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2", size = 4322283 }, + { url = "https://files.pythonhosted.org/packages/e4/c2/e25199e7e4e71d64eeb869f5b72c7ddec70e0a87926398785ab944d92375/pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070", size = 4402945 }, + { url = "https://files.pythonhosted.org/packages/c1/ed/51d6136c9d5911f78632b1b86c45241c712c5a80ed7fa7f9120a5dff1eba/pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6", size = 4361228 }, + { url = "https://files.pythonhosted.org/packages/48/a4/fbfe9d5581d7b111b28f1d8c2762dee92e9821bb209af9fa83c940e507a0/pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1", size = 4484021 }, + { url = "https://files.pythonhosted.org/packages/39/db/0b3c1a5018117f3c1d4df671fb8e47d08937f27519e8614bbe86153b65a5/pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2", size = 4287449 }, + { url = "https://files.pythonhosted.org/packages/d9/58/bc128da7fea8c89fc85e09f773c4901e95b5936000e6f303222490c052f3/pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96", size = 4419972 }, + { url = "https://files.pythonhosted.org/packages/5f/bb/58f34379bde9fe197f51841c5bbe8830c28bbb6d3801f16a83b8f2ad37df/pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f", size = 2291201 }, + { url = "https://files.pythonhosted.org/packages/3a/c6/fce9255272bcf0c39e15abd2f8fd8429a954cf344469eaceb9d0d1366913/pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761", size = 2625686 }, + { url = "https://files.pythonhosted.org/packages/c8/52/8ba066d569d932365509054859f74f2a9abee273edcef5cd75e4bc3e831e/pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71", size = 2375194 }, + { url = "https://files.pythonhosted.org/packages/fa/c5/389961578fb677b8b3244fcd934f720ed25a148b9a5cc81c91bdf59d8588/pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90", size = 3198345 }, + { url = "https://files.pythonhosted.org/packages/c4/fa/803c0e50ffee74d4b965229e816af55276eac1d5806712de86f9371858fd/pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb", size = 3072938 }, + { url = "https://files.pythonhosted.org/packages/dc/67/2a3a5f8012b5d8c63fe53958ba906c1b1d0482ebed5618057ef4d22f8076/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442", size = 3400049 }, + { url = "https://files.pythonhosted.org/packages/e5/a0/514f0d317446c98c478d1872497eb92e7cde67003fed74f696441e647446/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83", size = 3422431 }, + { url = "https://files.pythonhosted.org/packages/cd/00/20f40a935514037b7d3f87adfc87d2c538430ea625b63b3af8c3f5578e72/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f", size = 3446208 }, + { url = "https://files.pythonhosted.org/packages/28/3c/7de681727963043e093c72e6c3348411b0185eab3263100d4490234ba2f6/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73", size = 3509746 }, + { url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 }, +] + +[[package]] +name = "pip" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/3e/68beeeeb306ea20ffd30b3ed993f531d16cd884ec4f60c9b1e238f69f2af/pip-25.0.tar.gz", hash = "sha256:8e0a97f7b4c47ae4a494560da84775e9e2f671d415d8d828e052efefb206b30b", size = 1950328 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/8a/1ddf40be20103bcc605db840e9ade09c8e8c9f920a03e9cfe88eae97a058/pip-25.0-py3-none-any.whl", hash = "sha256:b6eb97a803356a52b2dd4bb73ba9e65b2ba16caa6bcb25a7497350a4e5859b65", size = 1841506 }, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "ply" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567 }, +] + +[[package]] +name = "pre-commit" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/13/b62d075317d8686071eb843f0bb1f195eb332f48869d3c31a4c6f1e063ac/pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4", size = 193330 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 }, +] + +[[package]] +name = "prompt-toolkit" +version = "3.0.50" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/e1/bd15cb8ffdcfeeb2bdc215de3c3cffca11408d829e4b8416dcfe71ba8854/prompt_toolkit-3.0.50.tar.gz", hash = "sha256:544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab", size = 429087 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 }, +] + +[[package]] +name = "psutil" +version = "6.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/5a/07871137bb752428aa4b659f910b399ba6f291156bdea939be3e96cae7cb/psutil-6.1.1.tar.gz", hash = "sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5", size = 508502 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/99/ca79d302be46f7bdd8321089762dd4476ee725fce16fc2b2e1dbba8cac17/psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8", size = 247511 }, + { url = "https://files.pythonhosted.org/packages/0b/6b/73dbde0dd38f3782905d4587049b9be64d76671042fdcaf60e2430c6796d/psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377", size = 248985 }, + { url = "https://files.pythonhosted.org/packages/17/38/c319d31a1d3f88c5b79c68b3116c129e5133f1822157dd6da34043e32ed6/psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003", size = 284488 }, + { url = "https://files.pythonhosted.org/packages/9c/39/0f88a830a1c8a3aba27fededc642da37613c57cbff143412e3536f89784f/psutil-6.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160", size = 287477 }, + { url = "https://files.pythonhosted.org/packages/47/da/99f4345d4ddf2845cb5b5bd0d93d554e84542d116934fde07a0c50bd4e9f/psutil-6.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3", size = 289017 }, + { url = "https://files.pythonhosted.org/packages/38/53/bd755c2896f4461fd4f36fa6a6dcb66a88a9e4b9fd4e5b66a77cf9d4a584/psutil-6.1.1-cp37-abi3-win32.whl", hash = "sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53", size = 250602 }, + { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 }, +] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 }, +] + +[[package]] +name = "pybind11" +version = "2.13.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a", size = 218403 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5", size = 243282 }, +] + +[[package]] +name = "pycparser" +version = "2.22" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, +] + +[[package]] +name = "pydantic" +version = "2.10.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/ae/d5220c5c52b158b1de7ca89fc5edb72f304a70a4c540c84c8844bf4008de/pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236", size = 761681 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/3c/8cc1cc84deffa6e25d2d0c688ebb80635dfdbf1dbea3e30c541c8cf4d860/pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584", size = 431696 }, +] + +[[package]] +name = "pydantic-core" +version = "2.27.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/01/f3e5ac5e7c25833db5eb555f7b7ab24cd6f8c322d3a3ad2d67a952dc0abc/pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39", size = 413443 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/bc/fed5f74b5d802cf9a03e83f60f18864e90e3aed7223adaca5ffb7a8d8d64/pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa", size = 1895938 }, + { url = "https://files.pythonhosted.org/packages/71/2a/185aff24ce844e39abb8dd680f4e959f0006944f4a8a0ea372d9f9ae2e53/pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c", size = 1815684 }, + { url = "https://files.pythonhosted.org/packages/c3/43/fafabd3d94d159d4f1ed62e383e264f146a17dd4d48453319fd782e7979e/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a", size = 1829169 }, + { url = "https://files.pythonhosted.org/packages/a2/d1/f2dfe1a2a637ce6800b799aa086d079998959f6f1215eb4497966efd2274/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5", size = 1867227 }, + { url = "https://files.pythonhosted.org/packages/7d/39/e06fcbcc1c785daa3160ccf6c1c38fea31f5754b756e34b65f74e99780b5/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c", size = 2037695 }, + { url = "https://files.pythonhosted.org/packages/7a/67/61291ee98e07f0650eb756d44998214231f50751ba7e13f4f325d95249ab/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7", size = 2741662 }, + { url = "https://files.pythonhosted.org/packages/32/90/3b15e31b88ca39e9e626630b4c4a1f5a0dfd09076366f4219429e6786076/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a", size = 1993370 }, + { url = "https://files.pythonhosted.org/packages/ff/83/c06d333ee3a67e2e13e07794995c1535565132940715931c1c43bfc85b11/pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236", size = 1996813 }, + { url = "https://files.pythonhosted.org/packages/7c/f7/89be1c8deb6e22618a74f0ca0d933fdcb8baa254753b26b25ad3acff8f74/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962", size = 2005287 }, + { url = "https://files.pythonhosted.org/packages/b7/7d/8eb3e23206c00ef7feee17b83a4ffa0a623eb1a9d382e56e4aa46fd15ff2/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9", size = 2128414 }, + { url = "https://files.pythonhosted.org/packages/4e/99/fe80f3ff8dd71a3ea15763878d464476e6cb0a2db95ff1c5c554133b6b83/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af", size = 2155301 }, + { url = "https://files.pythonhosted.org/packages/2b/a3/e50460b9a5789ca1451b70d4f52546fa9e2b420ba3bfa6100105c0559238/pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4", size = 1816685 }, + { url = "https://files.pythonhosted.org/packages/57/4c/a8838731cb0f2c2a39d3535376466de6049034d7b239c0202a64aaa05533/pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31", size = 1982876 }, + { url = "https://files.pythonhosted.org/packages/c2/89/f3450af9d09d44eea1f2c369f49e8f181d742f28220f88cc4dfaae91ea6e/pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc", size = 1893421 }, + { url = "https://files.pythonhosted.org/packages/9e/e3/71fe85af2021f3f386da42d291412e5baf6ce7716bd7101ea49c810eda90/pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7", size = 1814998 }, + { url = "https://files.pythonhosted.org/packages/a6/3c/724039e0d848fd69dbf5806894e26479577316c6f0f112bacaf67aa889ac/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15", size = 1826167 }, + { url = "https://files.pythonhosted.org/packages/2b/5b/1b29e8c1fb5f3199a9a57c1452004ff39f494bbe9bdbe9a81e18172e40d3/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306", size = 1865071 }, + { url = "https://files.pythonhosted.org/packages/89/6c/3985203863d76bb7d7266e36970d7e3b6385148c18a68cc8915fd8c84d57/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99", size = 2036244 }, + { url = "https://files.pythonhosted.org/packages/0e/41/f15316858a246b5d723f7d7f599f79e37493b2e84bfc789e58d88c209f8a/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459", size = 2737470 }, + { url = "https://files.pythonhosted.org/packages/a8/7c/b860618c25678bbd6d1d99dbdfdf0510ccb50790099b963ff78a124b754f/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048", size = 1992291 }, + { url = "https://files.pythonhosted.org/packages/bf/73/42c3742a391eccbeab39f15213ecda3104ae8682ba3c0c28069fbcb8c10d/pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d", size = 1994613 }, + { url = "https://files.pythonhosted.org/packages/94/7a/941e89096d1175d56f59340f3a8ebaf20762fef222c298ea96d36a6328c5/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b", size = 2002355 }, + { url = "https://files.pythonhosted.org/packages/6e/95/2359937a73d49e336a5a19848713555605d4d8d6940c3ec6c6c0ca4dcf25/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474", size = 2126661 }, + { url = "https://files.pythonhosted.org/packages/2b/4c/ca02b7bdb6012a1adef21a50625b14f43ed4d11f1fc237f9d7490aa5078c/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6", size = 2153261 }, + { url = "https://files.pythonhosted.org/packages/72/9d/a241db83f973049a1092a079272ffe2e3e82e98561ef6214ab53fe53b1c7/pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c", size = 1812361 }, + { url = "https://files.pythonhosted.org/packages/e8/ef/013f07248041b74abd48a385e2110aa3a9bbfef0fbd97d4e6d07d2f5b89a/pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc", size = 1982484 }, + { url = "https://files.pythonhosted.org/packages/10/1c/16b3a3e3398fd29dca77cea0a1d998d6bde3902fa2706985191e2313cc76/pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4", size = 1867102 }, + { url = "https://files.pythonhosted.org/packages/46/72/af70981a341500419e67d5cb45abe552a7c74b66326ac8877588488da1ac/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e", size = 1891159 }, + { url = "https://files.pythonhosted.org/packages/ad/3d/c5913cccdef93e0a6a95c2d057d2c2cba347815c845cda79ddd3c0f5e17d/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8", size = 1768331 }, + { url = "https://files.pythonhosted.org/packages/f6/f0/a3ae8fbee269e4934f14e2e0e00928f9346c5943174f2811193113e58252/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3", size = 1822467 }, + { url = "https://files.pythonhosted.org/packages/d7/7a/7bbf241a04e9f9ea24cd5874354a83526d639b02674648af3f350554276c/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f", size = 1979797 }, + { url = "https://files.pythonhosted.org/packages/4f/5f/4784c6107731f89e0005a92ecb8a2efeafdb55eb992b8e9d0a2be5199335/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133", size = 1987839 }, + { url = "https://files.pythonhosted.org/packages/6d/a7/61246562b651dff00de86a5f01b6e4befb518df314c54dec187a78d81c84/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc", size = 1998861 }, + { url = "https://files.pythonhosted.org/packages/86/aa/837821ecf0c022bbb74ca132e117c358321e72e7f9702d1b6a03758545e2/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50", size = 2116582 }, + { url = "https://files.pythonhosted.org/packages/81/b0/5e74656e95623cbaa0a6278d16cf15e10a51f6002e3ec126541e95c29ea3/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9", size = 2151985 }, + { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 }, +] + +[[package]] +name = "pydantic-settings" +version = "2.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/73/7b/c58a586cd7d9ac66d2ee4ba60ca2d241fa837c02bca9bea80a9a8c3d22a9/pydantic_settings-2.7.1.tar.gz", hash = "sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93", size = 79920 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/46/93416fdae86d40879714f72956ac14df9c7b76f7d41a4d68aa9f71a0028b/pydantic_settings-2.7.1-py3-none-any.whl", hash = "sha256:590be9e6e24d06db33a4262829edef682500ef008565a969c73d39d5f8bfb3fd", size = 29718 }, +] + +[[package]] +name = "pydot" +version = "3.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/dd/e0e6a4fb84c22050f6a9701ad9fd6a67ef82faa7ba97b97eb6fdc6b49b34/pydot-3.0.4.tar.gz", hash = "sha256:3ce88b2558f3808b0376f22bfa6c263909e1c3981e2a7b629b65b451eee4a25d", size = 168167 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/5f/1ebfd430df05c4f9e438dd3313c4456eab937d976f6ab8ce81a98f9fb381/pydot-3.0.4-py3-none-any.whl", hash = "sha256:bfa9c3fc0c44ba1d132adce131802d7df00429d1a79cc0346b0a5cd374dbe9c6", size = 35776 }, +] + +[[package]] +name = "pygls" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cattrs" }, + { name = "lsprotocol" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/b9/41d173dad9eaa9db9c785a85671fc3d68961f08d67706dc2e79011e10b5c/pygls-1.3.1.tar.gz", hash = "sha256:140edceefa0da0e9b3c533547c892a42a7d2fd9217ae848c330c53d266a55018", size = 45527 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/19/b74a10dd24548e96e8c80226cbacb28b021bc3a168a7d2709fb0d0185348/pygls-1.3.1-py3-none-any.whl", hash = "sha256:6e00f11efc56321bdeb6eac04f6d86131f654c7d49124344a9ebb968da3dd91e", size = 56031 }, +] + +[[package]] +name = "pygments" +version = "2.19.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 }, +] + +[[package]] +name = "pyparsing" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/1a/3544f4f299a47911c2ab3710f534e52fea62a633c96806995da5d25be4b2/pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a", size = 1067694 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/a7/c8a2d361bf89c0d9577c934ebb7421b25dc84bf3a8e3ac0a40aed9acc547/pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1", size = 107716 }, +] + +[[package]] +name = "pyreadline" +version = "2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/7c/d724ef1ec3ab2125f38a1d53285745445ec4a8f19b9bb0761b4064316679/pyreadline-2.1.zip", hash = "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1", size = 109189 } + +[[package]] +name = "pyspellchecker" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/42/5d/86d94aceb9c0813f27004ec71c036d8ec6a6324d989854ff0fe13fe036dc/pyspellchecker-0.8.2.tar.gz", hash = "sha256:2b026be14a162ba810bdda8e5454c56e364f42d3b9e14aeff31706e5ebcdc78f", size = 7149207 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/8e/7c79443d302a80cfd59bc365938d51e36e7e9aa7ce8ab1d8a0ca0c8e6065/pyspellchecker-0.8.2-py3-none-any.whl", hash = "sha256:4fee22e1859c5153c3bc3953ac3041bf07d4541520b7e01901e955062022290a", size = 7147898 }, +] + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + +[[package]] +name = "pytest-benchmark" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/d0/a8bd08d641b393db3be3819b03e2d9bb8760ca8479080a26a5f6e540e99c/pytest-benchmark-5.1.0.tar.gz", hash = "sha256:9ea661cdc292e8231f7cd4c10b0319e56a2118e2c09d9f50e1b3d150d2aca105", size = 337810 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/d6/b41653199ea09d5969d4e385df9bbfd9a100f28ca7e824ce7c0a016e3053/pytest_benchmark-5.1.0-py3-none-any.whl", hash = "sha256:922de2dfa3033c227c96da942d1878191afa135a29485fb942e85dff1c592c89", size = 44259 }, +] + +[[package]] +name = "pytest-cache" +version = "1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/15/082fd0428aab33d2bafa014f3beb241830427ba803a8912a5aaeaf3a5663/pytest-cache-1.0.tar.gz", hash = "sha256:be7468edd4d3d83f1e844959fd6e3fd28e77a481440a7118d430130ea31b07a9", size = 16242 } + +[[package]] +name = "pytest-cov" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/45/9b538de8cef30e17c7b45ef42f538a94889ed6a16f2387a6c89e73220651/pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0", size = 66945 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949 }, +] + +[[package]] +name = "pytest-factoryboy" +version = "2.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "factory-boy" }, + { name = "inflection" }, + { name = "packaging" }, + { name = "pytest" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/bc/179653e8cce651575ac95377e4fdf9afd3c4821ab4bba101aae913ebcc27/pytest_factoryboy-2.7.0.tar.gz", hash = "sha256:67fc54ec8669a3feb8ac60094dd57cd71eb0b20b2c319d2957873674c776a77b", size = 17398 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/56/d3ef25286dc8df9d1da0b325ee4b1b1ffd9736e44f9b30cfbe464e9f4f14/pytest_factoryboy-2.7.0-py3-none-any.whl", hash = "sha256:bf3222db22d954fbf46f4bff902a0a8d82f3fc3594a47c04bbdc0546ff4c59a6", size = 16268 }, +] + +[[package]] +name = "pytest-instafail" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/bd/e0ba6c3cd20b9aa445f0af229f3a9582cce589f083537978a23e6f14e310/pytest-instafail-0.5.0.tar.gz", hash = "sha256:33a606f7e0c8e646dc3bfee0d5e3a4b7b78ef7c36168cfa1f3d93af7ca706c9e", size = 5849 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/c0/c32dc39fc172e684fdb3d30169843efb65c067be1e12689af4345731126e/pytest_instafail-0.5.0-py3-none-any.whl", hash = "sha256:6855414487e9e4bb76a118ce952c3c27d3866af15487506c4ded92eb72387819", size = 4176 }, +] + +[[package]] +name = "pytest-xdist" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/c4/3c310a19bc1f1e9ef50075582652673ef2bfc8cd62afef9585683821902f/pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d", size = 84060 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/82/1d96bf03ee4c0fdc3c0cbe61470070e659ca78dc0086fb88b66c185e2449/pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7", size = 46108 }, +] + +[package.optional-dependencies] +psutil = [ + { name = "psutil" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, +] + +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + +[[package]] +name = "pywin32" +version = "308" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/a6/3e9f2c474895c1bb61b11fa9640be00067b5c5b363c501ee9c3fa53aec01/pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e", size = 5927028 }, + { url = "https://files.pythonhosted.org/packages/d9/b4/84e2463422f869b4b718f79eb7530a4c1693e96b8a4e5e968de38be4d2ba/pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e", size = 6558484 }, + { url = "https://files.pythonhosted.org/packages/9f/8f/fb84ab789713f7c6feacaa08dad3ec8105b88ade8d1c4f0f0dfcaaa017d6/pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c", size = 7971454 }, + { url = "https://files.pythonhosted.org/packages/eb/e2/02652007469263fe1466e98439831d65d4ca80ea1a2df29abecedf7e47b7/pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a", size = 5928156 }, + { url = "https://files.pythonhosted.org/packages/48/ef/f4fb45e2196bc7ffe09cad0542d9aff66b0e33f6c0954b43e49c33cad7bd/pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b", size = 6559559 }, + { url = "https://files.pythonhosted.org/packages/79/ef/68bb6aa865c5c9b11a35771329e95917b5559845bd75b65549407f9fc6b4/pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6", size = 7972495 }, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/95/a3fac87cb7158e231b5a6012e438c647e1a87f09f8e0d123acec8ab8bf71/PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", size = 184199 }, + { url = "https://files.pythonhosted.org/packages/c7/7a/68bd47624dab8fd4afbfd3c48e3b79efe09098ae941de5b58abcbadff5cb/PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", size = 171758 }, + { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463 }, + { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280 }, + { url = "https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239 }, + { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802 }, + { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527 }, + { url = "https://files.pythonhosted.org/packages/be/aa/5afe99233fb360d0ff37377145a949ae258aaab831bde4792b32650a4378/PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", size = 144052 }, + { url = "https://files.pythonhosted.org/packages/b5/84/0fa4b06f6d6c958d207620fc60005e241ecedceee58931bb20138e1e5776/PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", size = 161774 }, + { url = "https://files.pythonhosted.org/packages/f8/aa/7af4e81f7acba21a4c6be026da38fd2b872ca46226673c89a758ebdc4fd2/PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", size = 184612 }, + { url = "https://files.pythonhosted.org/packages/8b/62/b9faa998fd185f65c1371643678e4d58254add437edb764a08c5a98fb986/PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", size = 172040 }, + { url = "https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829 }, + { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167 }, + { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952 }, + { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301 }, + { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638 }, + { url = "https://files.pythonhosted.org/packages/22/5f/956f0f9fc65223a58fbc14459bf34b4cc48dec52e00535c79b8db361aabd/PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", size = 143850 }, + { url = "https://files.pythonhosted.org/packages/ed/23/8da0bbe2ab9dcdd11f4f4557ccaf95c10b9811b13ecced089d43ce59c3c8/PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", size = 161980 }, +] + +[[package]] +name = "pyzmq" +version = "26.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/05/bed626b9f7bb2322cdbbf7b4bd8f54b1b617b0d2ab2d3547d6e39428a48e/pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f", size = 271975 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/a8/9837c39aba390eb7d01924ace49d761c8dbe7bc2d6082346d00c8332e431/pyzmq-26.2.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:ddf33d97d2f52d89f6e6e7ae66ee35a4d9ca6f36eda89c24591b0c40205a3629", size = 1340058 }, + { url = "https://files.pythonhosted.org/packages/a2/1f/a006f2e8e4f7d41d464272012695da17fb95f33b54342612a6890da96ff6/pyzmq-26.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dacd995031a01d16eec825bf30802fceb2c3791ef24bcce48fa98ce40918c27b", size = 1008818 }, + { url = "https://files.pythonhosted.org/packages/b6/09/b51b6683fde5ca04593a57bbe81788b6b43114d8f8ee4e80afc991e14760/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89289a5ee32ef6c439086184529ae060c741334b8970a6855ec0b6ad3ff28764", size = 673199 }, + { url = "https://files.pythonhosted.org/packages/c9/78/486f3e2e824f3a645238332bf5a4c4b4477c3063033a27c1e4052358dee2/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5506f06d7dc6ecf1efacb4a013b1f05071bb24b76350832c96449f4a2d95091c", size = 911762 }, + { url = "https://files.pythonhosted.org/packages/5e/3b/2eb1667c9b866f53e76ee8b0c301b0469745a23bd5a87b7ee3d5dd9eb6e5/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea039387c10202ce304af74def5021e9adc6297067f3441d348d2b633e8166a", size = 868773 }, + { url = "https://files.pythonhosted.org/packages/16/29/ca99b4598a9dc7e468b5417eda91f372b595be1e3eec9b7cbe8e5d3584e8/pyzmq-26.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2224fa4a4c2ee872886ed00a571f5e967c85e078e8e8c2530a2fb01b3309b88", size = 868834 }, + { url = "https://files.pythonhosted.org/packages/ad/e5/9efaeb1d2f4f8c50da04144f639b042bc52869d3a206d6bf672ab3522163/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:28ad5233e9c3b52d76196c696e362508959741e1a005fb8fa03b51aea156088f", size = 1202861 }, + { url = "https://files.pythonhosted.org/packages/c3/62/c721b5608a8ac0a69bb83cbb7d07a56f3ff00b3991a138e44198a16f94c7/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:1c17211bc037c7d88e85ed8b7d8f7e52db6dc8eca5590d162717c654550f7282", size = 1515304 }, + { url = "https://files.pythonhosted.org/packages/87/84/e8bd321aa99b72f48d4606fc5a0a920154125bd0a4608c67eab742dab087/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b8f86dd868d41bea9a5f873ee13bf5551c94cf6bc51baebc6f85075971fe6eea", size = 1414712 }, + { url = "https://files.pythonhosted.org/packages/cd/cd/420e3fd1ac6977b008b72e7ad2dae6350cc84d4c5027fc390b024e61738f/pyzmq-26.2.0-cp310-cp310-win32.whl", hash = "sha256:46a446c212e58456b23af260f3d9fb785054f3e3653dbf7279d8f2b5546b21c2", size = 578113 }, + { url = "https://files.pythonhosted.org/packages/5c/57/73930d56ed45ae0cb4946f383f985c855c9b3d4063f26416998f07523c0e/pyzmq-26.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:49d34ab71db5a9c292a7644ce74190b1dd5a3475612eefb1f8be1d6961441971", size = 641631 }, + { url = "https://files.pythonhosted.org/packages/61/d2/ae6ac5c397f1ccad59031c64beaafce7a0d6182e0452cc48f1c9c87d2dd0/pyzmq-26.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:bfa832bfa540e5b5c27dcf5de5d82ebc431b82c453a43d141afb1e5d2de025fa", size = 543528 }, + { url = "https://files.pythonhosted.org/packages/12/20/de7442172f77f7c96299a0ac70e7d4fb78cd51eca67aa2cf552b66c14196/pyzmq-26.2.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:8f7e66c7113c684c2b3f1c83cdd3376103ee0ce4c49ff80a648643e57fb22218", size = 1340639 }, + { url = "https://files.pythonhosted.org/packages/98/4d/5000468bd64c7910190ed0a6c76a1ca59a68189ec1f007c451dc181a22f4/pyzmq-26.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3a495b30fc91db2db25120df5847d9833af237546fd59170701acd816ccc01c4", size = 1008710 }, + { url = "https://files.pythonhosted.org/packages/e1/bf/c67fd638c2f9fbbab8090a3ee779370b97c82b84cc12d0c498b285d7b2c0/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77eb0968da535cba0470a5165468b2cac7772cfb569977cff92e240f57e31bef", size = 673129 }, + { url = "https://files.pythonhosted.org/packages/86/94/99085a3f492aa538161cbf27246e8886ff850e113e0c294a5b8245f13b52/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ace4f71f1900a548f48407fc9be59c6ba9d9aaf658c2eea6cf2779e72f9f317", size = 910107 }, + { url = "https://files.pythonhosted.org/packages/31/1d/346809e8a9b999646d03f21096428453465b1bca5cd5c64ecd048d9ecb01/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92a78853d7280bffb93df0a4a6a2498cba10ee793cc8076ef797ef2f74d107cf", size = 867960 }, + { url = "https://files.pythonhosted.org/packages/ab/68/6fb6ae5551846ad5beca295b7bca32bf0a7ce19f135cb30e55fa2314e6b6/pyzmq-26.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:689c5d781014956a4a6de61d74ba97b23547e431e9e7d64f27d4922ba96e9d6e", size = 869204 }, + { url = "https://files.pythonhosted.org/packages/0f/f9/18417771dee223ccf0f48e29adf8b4e25ba6d0e8285e33bcbce078070bc3/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aca98bc423eb7d153214b2df397c6421ba6373d3397b26c057af3c904452e37", size = 1203351 }, + { url = "https://files.pythonhosted.org/packages/e0/46/f13e67fe0d4f8a2315782cbad50493de6203ea0d744610faf4d5f5b16e90/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f3496d76b89d9429a656293744ceca4d2ac2a10ae59b84c1da9b5165f429ad3", size = 1514204 }, + { url = "https://files.pythonhosted.org/packages/50/11/ddcf7343b7b7a226e0fc7b68cbf5a5bb56291fac07f5c3023bb4c319ebb4/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5c2b3bfd4b9689919db068ac6c9911f3fcb231c39f7dd30e3138be94896d18e6", size = 1414339 }, + { url = "https://files.pythonhosted.org/packages/01/14/1c18d7d5b7be2708f513f37c61bfadfa62161c10624f8733f1c8451b3509/pyzmq-26.2.0-cp311-cp311-win32.whl", hash = "sha256:eac5174677da084abf378739dbf4ad245661635f1600edd1221f150b165343f4", size = 576928 }, + { url = "https://files.pythonhosted.org/packages/3b/1b/0a540edd75a41df14ec416a9a500b9fec66e554aac920d4c58fbd5756776/pyzmq-26.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:5a509df7d0a83a4b178d0f937ef14286659225ef4e8812e05580776c70e155d5", size = 642317 }, + { url = "https://files.pythonhosted.org/packages/98/77/1cbfec0358078a4c5add529d8a70892db1be900980cdb5dd0898b3d6ab9d/pyzmq-26.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:c0e6091b157d48cbe37bd67233318dbb53e1e6327d6fc3bb284afd585d141003", size = 543834 }, + { url = "https://files.pythonhosted.org/packages/53/fb/36b2b2548286e9444e52fcd198760af99fd89102b5be50f0660fcfe902df/pyzmq-26.2.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:706e794564bec25819d21a41c31d4df2d48e1cc4b061e8d345d7fb4dd3e94072", size = 906955 }, + { url = "https://files.pythonhosted.org/packages/77/8f/6ce54f8979a01656e894946db6299e2273fcee21c8e5fa57c6295ef11f57/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b435f2753621cd36e7c1762156815e21c985c72b19135dac43a7f4f31d28dd1", size = 565701 }, + { url = "https://files.pythonhosted.org/packages/ee/1c/bf8cd66730a866b16db8483286078892b7f6536f8c389fb46e4beba0a970/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160c7e0a5eb178011e72892f99f918c04a131f36056d10d9c1afb223fc952c2d", size = 794312 }, + { url = "https://files.pythonhosted.org/packages/71/43/91fa4ff25bbfdc914ab6bafa0f03241d69370ef31a761d16bb859f346582/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4a71d5d6e7b28a47a394c0471b7e77a0661e2d651e7ae91e0cab0a587859ca", size = 752775 }, + { url = "https://files.pythonhosted.org/packages/ec/d2/3b2ab40f455a256cb6672186bea95cd97b459ce4594050132d71e76f0d6f/pyzmq-26.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:90412f2db8c02a3864cbfc67db0e3dcdbda336acf1c469526d3e869394fe001c", size = 550762 }, +] + +[[package]] +name = "questionary" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prompt-toolkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/b8/d16eb579277f3de9e56e5ad25280fab52fc5774117fb70362e8c2e016559/questionary-2.1.0.tar.gz", hash = "sha256:6302cdd645b19667d8f6e6634774e9538bfcd1aad9be287e743d96cacaf95587", size = 26775 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/3f/11dd4cd4f39e05128bfd20138faea57bec56f9ffba6185d276e3107ba5b2/questionary-2.1.0-py3-none-any.whl", hash = "sha256:44174d237b68bc828e4878c763a9ad6790ee61990e0ae72927694ead57bab8ec", size = 36747 }, +] + +[[package]] +name = "referencing" +version = "0.36.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/db/98b5c277be99dd18bfd91dd04e1b759cad18d1a338188c936e92f921c7e2/referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa", size = 74744 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/b1/3baf80dc6d2b7bc27a95a67752d0208e410351e3feb4eb78de5f77454d8d/referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0", size = 26775 }, +] + +[[package]] +name = "requests" +version = "2.32.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, +] + +[[package]] +name = "rich" +version = "13.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, +] + +[[package]] +name = "rich-click" +version = "1.8.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/31/103501e85e885e3e202c087fa612cfe450693210372766552ce1ab5b57b9/rich_click-1.8.5.tar.gz", hash = "sha256:a3eebe81da1c9da3c32f3810017c79bd687ff1b3fa35bfc9d8a3338797f1d1a1", size = 38229 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/0b/e2de98c538c0ee9336211d260f88b7e69affab44969750aaca0b48a697c8/rich_click-1.8.5-py3-none-any.whl", hash = "sha256:0fab7bb5b66c15da17c210b4104277cd45f3653a7322e0098820a169880baee0", size = 35081 }, +] + +[[package]] +name = "rpds-py" +version = "0.22.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/80/cce854d0921ff2f0a9fa831ba3ad3c65cee3a46711addf39a2af52df2cfd/rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d", size = 26771 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/2a/ead1d09e57449b99dcc190d8d2323e3a167421d8f8fdf0f217c6f6befe47/rpds_py-0.22.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6c7b99ca52c2c1752b544e310101b98a659b720b21db00e65edca34483259967", size = 359514 }, + { url = "https://files.pythonhosted.org/packages/8f/7e/1254f406b7793b586c68e217a6a24ec79040f85e030fff7e9049069284f4/rpds_py-0.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be2eb3f2495ba669d2a985f9b426c1797b7d48d6963899276d22f23e33d47e37", size = 349031 }, + { url = "https://files.pythonhosted.org/packages/aa/da/17c6a2c73730d426df53675ff9cc6653ac7a60b6438d03c18e1c822a576a/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70eb60b3ae9245ddea20f8a4190bd79c705a22f8028aaf8bbdebe4716c3fab24", size = 381485 }, + { url = "https://files.pythonhosted.org/packages/aa/13/2dbacd820466aa2a3c4b747afb18d71209523d353cf865bf8f4796c969ea/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4041711832360a9b75cfb11b25a6a97c8fb49c07b8bd43d0d02b45d0b499a4ff", size = 386794 }, + { url = "https://files.pythonhosted.org/packages/6d/62/96905d0a35ad4e4bc3c098b2f34b2e7266e211d08635baa690643d2227be/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64607d4cbf1b7e3c3c8a14948b99345eda0e161b852e122c6bb71aab6d1d798c", size = 423523 }, + { url = "https://files.pythonhosted.org/packages/eb/1b/d12770f2b6a9fc2c3ec0d810d7d440f6d465ccd8b7f16ae5385952c28b89/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e69b0a0e2537f26d73b4e43ad7bc8c8efb39621639b4434b76a3de50c6966e", size = 446695 }, + { url = "https://files.pythonhosted.org/packages/4d/cf/96f1fd75512a017f8e07408b6d5dbeb492d9ed46bfe0555544294f3681b3/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc27863442d388870c1809a87507727b799c8460573cfbb6dc0eeaef5a11b5ec", size = 381959 }, + { url = "https://files.pythonhosted.org/packages/ab/f0/d1c5b501c8aea85aeb938b555bfdf7612110a2f8cdc21ae0482c93dd0c24/rpds_py-0.22.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e79dd39f1e8c3504be0607e5fc6e86bb60fe3584bec8b782578c3b0fde8d932c", size = 410420 }, + { url = "https://files.pythonhosted.org/packages/33/3b/45b6c58fb6aad5a569ae40fb890fc494c6b02203505a5008ee6dc68e65f7/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e0fa2d4ec53dc51cf7d3bb22e0aa0143966119f42a0c3e4998293a3dd2856b09", size = 557620 }, + { url = "https://files.pythonhosted.org/packages/83/62/3fdd2d3d47bf0bb9b931c4c73036b4ab3ec77b25e016ae26fab0f02be2af/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fda7cb070f442bf80b642cd56483b5548e43d366fe3f39b98e67cce780cded00", size = 584202 }, + { url = "https://files.pythonhosted.org/packages/04/f2/5dced98b64874b84ca824292f9cee2e3f30f3bcf231d15a903126684f74d/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cff63a0272fcd259dcc3be1657b07c929c466b067ceb1c20060e8d10af56f5bf", size = 552787 }, + { url = "https://files.pythonhosted.org/packages/67/13/2273dea1204eda0aea0ef55145da96a9aa28b3f88bb5c70e994f69eda7c3/rpds_py-0.22.3-cp310-cp310-win32.whl", hash = "sha256:9bd7228827ec7bb817089e2eb301d907c0d9827a9e558f22f762bb690b131652", size = 220088 }, + { url = "https://files.pythonhosted.org/packages/4e/80/8c8176b67ad7f4a894967a7a4014ba039626d96f1d4874d53e409b58d69f/rpds_py-0.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:9beeb01d8c190d7581a4d59522cd3d4b6887040dcfc744af99aa59fef3e041a8", size = 231737 }, + { url = "https://files.pythonhosted.org/packages/15/ad/8d1ddf78f2805a71253fcd388017e7b4a0615c22c762b6d35301fef20106/rpds_py-0.22.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d20cfb4e099748ea39e6f7b16c91ab057989712d31761d3300d43134e26e165f", size = 359773 }, + { url = "https://files.pythonhosted.org/packages/c8/75/68c15732293a8485d79fe4ebe9045525502a067865fa4278f178851b2d87/rpds_py-0.22.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:68049202f67380ff9aa52f12e92b1c30115f32e6895cd7198fa2a7961621fc5a", size = 349214 }, + { url = "https://files.pythonhosted.org/packages/3c/4c/7ce50f3070083c2e1b2bbd0fb7046f3da55f510d19e283222f8f33d7d5f4/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb4f868f712b2dd4bcc538b0a0c1f63a2b1d584c925e69a224d759e7070a12d5", size = 380477 }, + { url = "https://files.pythonhosted.org/packages/9a/e9/835196a69cb229d5c31c13b8ae603bd2da9a6695f35fe4270d398e1db44c/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bc51abd01f08117283c5ebf64844a35144a0843ff7b2983e0648e4d3d9f10dbb", size = 386171 }, + { url = "https://files.pythonhosted.org/packages/f9/8e/33fc4eba6683db71e91e6d594a2cf3a8fbceb5316629f0477f7ece5e3f75/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f3cec041684de9a4684b1572fe28c7267410e02450f4561700ca5a3bc6695a2", size = 422676 }, + { url = "https://files.pythonhosted.org/packages/37/47/2e82d58f8046a98bb9497a8319604c92b827b94d558df30877c4b3c6ccb3/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ef9d9da710be50ff6809fed8f1963fecdfecc8b86656cadfca3bc24289414b0", size = 446152 }, + { url = "https://files.pythonhosted.org/packages/e1/78/79c128c3e71abbc8e9739ac27af11dc0f91840a86fce67ff83c65d1ba195/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59f4a79c19232a5774aee369a0c296712ad0e77f24e62cad53160312b1c1eaa1", size = 381300 }, + { url = "https://files.pythonhosted.org/packages/c9/5b/2e193be0e8b228c1207f31fa3ea79de64dadb4f6a4833111af8145a6bc33/rpds_py-0.22.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a60bce91f81ddaac922a40bbb571a12c1070cb20ebd6d49c48e0b101d87300d", size = 409636 }, + { url = "https://files.pythonhosted.org/packages/c2/3f/687c7100b762d62186a1c1100ffdf99825f6fa5ea94556844bbbd2d0f3a9/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e89391e6d60251560f0a8f4bd32137b077a80d9b7dbe6d5cab1cd80d2746f648", size = 556708 }, + { url = "https://files.pythonhosted.org/packages/8c/a2/c00cbc4b857e8b3d5e7f7fc4c81e23afd8c138b930f4f3ccf9a41a23e9e4/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3fb866d9932a3d7d0c82da76d816996d1667c44891bd861a0f97ba27e84fc74", size = 583554 }, + { url = "https://files.pythonhosted.org/packages/d0/08/696c9872cf56effdad9ed617ac072f6774a898d46b8b8964eab39ec562d2/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1352ae4f7c717ae8cba93421a63373e582d19d55d2ee2cbb184344c82d2ae55a", size = 552105 }, + { url = "https://files.pythonhosted.org/packages/18/1f/4df560be1e994f5adf56cabd6c117e02de7c88ee238bb4ce03ed50da9d56/rpds_py-0.22.3-cp311-cp311-win32.whl", hash = "sha256:b0b4136a252cadfa1adb705bb81524eee47d9f6aab4f2ee4fa1e9d3cd4581f64", size = 220199 }, + { url = "https://files.pythonhosted.org/packages/b8/1b/c29b570bc5db8237553002788dc734d6bd71443a2ceac2a58202ec06ef12/rpds_py-0.22.3-cp311-cp311-win_amd64.whl", hash = "sha256:8bd7c8cfc0b8247c8799080fbff54e0b9619e17cdfeb0478ba7295d43f635d7c", size = 231775 }, + { url = "https://files.pythonhosted.org/packages/8b/63/e29f8ee14fcf383574f73b6bbdcbec0fbc2e5fc36b4de44d1ac389b1de62/rpds_py-0.22.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:d48424e39c2611ee1b84ad0f44fb3b2b53d473e65de061e3f460fc0be5f1939d", size = 360786 }, + { url = "https://files.pythonhosted.org/packages/d3/e0/771ee28b02a24e81c8c0e645796a371350a2bb6672753144f36ae2d2afc9/rpds_py-0.22.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:24e8abb5878e250f2eb0d7859a8e561846f98910326d06c0d51381fed59357bd", size = 350589 }, + { url = "https://files.pythonhosted.org/packages/cf/49/abad4c4a1e6f3adf04785a99c247bfabe55ed868133e2d1881200aa5d381/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b232061ca880db21fa14defe219840ad9b74b6158adb52ddf0e87bead9e8493", size = 381848 }, + { url = "https://files.pythonhosted.org/packages/3a/7d/f4bc6d6fbe6af7a0d2b5f2ee77079efef7c8528712745659ec0026888998/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac0a03221cdb5058ce0167ecc92a8c89e8d0decdc9e99a2ec23380793c4dcb96", size = 387879 }, + { url = "https://files.pythonhosted.org/packages/13/b0/575c797377fdcd26cedbb00a3324232e4cb2c5d121f6e4b0dbf8468b12ef/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb0c341fa71df5a4595f9501df4ac5abfb5a09580081dffbd1ddd4654e6e9123", size = 423916 }, + { url = "https://files.pythonhosted.org/packages/54/78/87157fa39d58f32a68d3326f8a81ad8fb99f49fe2aa7ad9a1b7d544f9478/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf9db5488121b596dbfc6718c76092fda77b703c1f7533a226a5a9f65248f8ad", size = 448410 }, + { url = "https://files.pythonhosted.org/packages/59/69/860f89996065a88be1b6ff2d60e96a02b920a262d8aadab99e7903986597/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8db6b5b2d4491ad5b6bdc2bc7c017eec108acbf4e6785f42a9eb0ba234f4c9", size = 382841 }, + { url = "https://files.pythonhosted.org/packages/bd/d7/bc144e10d27e3cb350f98df2492a319edd3caaf52ddfe1293f37a9afbfd7/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b3d504047aba448d70cf6fa22e06cb09f7cbd761939fdd47604f5e007675c24e", size = 409662 }, + { url = "https://files.pythonhosted.org/packages/14/2a/6bed0b05233c291a94c7e89bc76ffa1c619d4e1979fbfe5d96024020c1fb/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:e61b02c3f7a1e0b75e20c3978f7135fd13cb6cf551bf4a6d29b999a88830a338", size = 558221 }, + { url = "https://files.pythonhosted.org/packages/11/23/cd8f566de444a137bc1ee5795e47069a947e60810ba4152886fe5308e1b7/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:e35ba67d65d49080e8e5a1dd40101fccdd9798adb9b050ff670b7d74fa41c566", size = 583780 }, + { url = "https://files.pythonhosted.org/packages/8d/63/79c3602afd14d501f751e615a74a59040328da5ef29ed5754ae80d236b84/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:26fd7cac7dd51011a245f29a2cc6489c4608b5a8ce8d75661bb4a1066c52dfbe", size = 553619 }, + { url = "https://files.pythonhosted.org/packages/9f/2e/c5c1689e80298d4e94c75b70faada4c25445739d91b94c211244a3ed7ed1/rpds_py-0.22.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d", size = 233338 }, +] + +[[package]] +name = "ruamel-yaml" +version = "0.18.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ruamel-yaml-clib", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/46/f44d8be06b85bc7c4d8c95d658be2b68f27711f279bf9dd0612a5e4794f5/ruamel.yaml-0.18.10.tar.gz", hash = "sha256:20c86ab29ac2153f80a428e1254a8adf686d3383df04490514ca3b79a362db58", size = 143447 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/36/dfc1ebc0081e6d39924a2cc53654497f967a084a436bb64402dfce4254d9/ruamel.yaml-0.18.10-py3-none-any.whl", hash = "sha256:30f22513ab2301b3d2b577adc121c6471f28734d3d9728581245f1e76468b4f1", size = 117729 }, +] + +[[package]] +name = "ruamel-yaml-clib" +version = "0.2.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/84/80203abff8ea4993a87d823a5f632e4d92831ef75d404c9fc78d0176d2b5/ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f", size = 225315 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/57/40a958e863e299f0c74ef32a3bde9f2d1ea8d69669368c0c502a0997f57f/ruamel.yaml.clib-0.2.12-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:11f891336688faf5156a36293a9c362bdc7c88f03a8a027c2c1d8e0bcde998e5", size = 131301 }, + { url = "https://files.pythonhosted.org/packages/98/a8/29a3eb437b12b95f50a6bcc3d7d7214301c6c529d8fdc227247fa84162b5/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969", size = 633728 }, + { url = "https://files.pythonhosted.org/packages/35/6d/ae05a87a3ad540259c3ad88d71275cbd1c0f2d30ae04c65dcbfb6dcd4b9f/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df", size = 722230 }, + { url = "https://files.pythonhosted.org/packages/7f/b7/20c6f3c0b656fe609675d69bc135c03aac9e3865912444be6339207b6648/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76", size = 686712 }, + { url = "https://files.pythonhosted.org/packages/cd/11/d12dbf683471f888d354dac59593873c2b45feb193c5e3e0f2ebf85e68b9/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6", size = 663936 }, + { url = "https://files.pythonhosted.org/packages/72/14/4c268f5077db5c83f743ee1daeb236269fa8577133a5cfa49f8b382baf13/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd", size = 696580 }, + { url = "https://files.pythonhosted.org/packages/30/fc/8cd12f189c6405a4c1cf37bd633aa740a9538c8e40497c231072d0fef5cf/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a", size = 663393 }, + { url = "https://files.pythonhosted.org/packages/80/29/c0a017b704aaf3cbf704989785cd9c5d5b8ccec2dae6ac0c53833c84e677/ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da", size = 100326 }, + { url = "https://files.pythonhosted.org/packages/3a/65/fa39d74db4e2d0cd252355732d966a460a41cd01c6353b820a0952432839/ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28", size = 118079 }, + { url = "https://files.pythonhosted.org/packages/fb/8f/683c6ad562f558cbc4f7c029abcd9599148c51c54b5ef0f24f2638da9fbb/ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6", size = 132224 }, + { url = "https://files.pythonhosted.org/packages/3c/d2/b79b7d695e2f21da020bd44c782490578f300dd44f0a4c57a92575758a76/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e", size = 641480 }, + { url = "https://files.pythonhosted.org/packages/68/6e/264c50ce2a31473a9fdbf4fa66ca9b2b17c7455b31ef585462343818bd6c/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e", size = 739068 }, + { url = "https://files.pythonhosted.org/packages/86/29/88c2567bc893c84d88b4c48027367c3562ae69121d568e8a3f3a8d363f4d/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52", size = 703012 }, + { url = "https://files.pythonhosted.org/packages/11/46/879763c619b5470820f0cd6ca97d134771e502776bc2b844d2adb6e37753/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642", size = 704352 }, + { url = "https://files.pythonhosted.org/packages/02/80/ece7e6034256a4186bbe50dee28cd032d816974941a6abf6a9d65e4228a7/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2", size = 737344 }, + { url = "https://files.pythonhosted.org/packages/f0/ca/e4106ac7e80efbabdf4bf91d3d32fc424e41418458251712f5672eada9ce/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3", size = 714498 }, + { url = "https://files.pythonhosted.org/packages/67/58/b1f60a1d591b771298ffa0428237afb092c7f29ae23bad93420b1eb10703/ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4", size = 100205 }, + { url = "https://files.pythonhosted.org/packages/b4/4f/b52f634c9548a9291a70dfce26ca7ebce388235c93588a1068028ea23fcc/ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb", size = 118185 }, +] + +[[package]] +name = "ruff" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/7f/60fda2eec81f23f8aa7cbbfdf6ec2ca11eb11c273827933fb2541c2ce9d8/ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a", size = 3586740 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/77/4fb790596d5d52c87fd55b7160c557c400e90f6116a56d82d76e95d9374a/ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624", size = 11656815 }, + { url = "https://files.pythonhosted.org/packages/a2/a8/3338ecb97573eafe74505f28431df3842c1933c5f8eae615427c1de32858/ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c", size = 11594821 }, + { url = "https://files.pythonhosted.org/packages/8e/89/320223c3421962762531a6b2dd58579b858ca9916fb2674874df5e97d628/ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4", size = 11040475 }, + { url = "https://files.pythonhosted.org/packages/b2/bd/1d775eac5e51409535804a3a888a9623e87a8f4b53e2491580858a083692/ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439", size = 11856207 }, + { url = "https://files.pythonhosted.org/packages/7f/c6/3e14e09be29587393d188454064a4aa85174910d16644051a80444e4fd88/ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5", size = 11420460 }, + { url = "https://files.pythonhosted.org/packages/ef/42/b7ca38ffd568ae9b128a2fa76353e9a9a3c80ef19746408d4ce99217ecc1/ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4", size = 12605472 }, + { url = "https://files.pythonhosted.org/packages/a6/a1/3167023f23e3530fde899497ccfe239e4523854cb874458ac082992d206c/ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1", size = 13243123 }, + { url = "https://files.pythonhosted.org/packages/d0/b4/3c600758e320f5bf7de16858502e849f4216cb0151f819fa0d1154874802/ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5", size = 12744650 }, + { url = "https://files.pythonhosted.org/packages/be/38/266fbcbb3d0088862c9bafa8b1b99486691d2945a90b9a7316336a0d9a1b/ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4", size = 14458585 }, + { url = "https://files.pythonhosted.org/packages/63/a6/47fd0e96990ee9b7a4abda62de26d291bd3f7647218d05b7d6d38af47c30/ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6", size = 12419624 }, + { url = "https://files.pythonhosted.org/packages/84/5d/de0b7652e09f7dda49e1a3825a164a65f4998175b6486603c7601279baad/ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730", size = 11843238 }, + { url = "https://files.pythonhosted.org/packages/9e/be/3f341ceb1c62b565ec1fb6fd2139cc40b60ae6eff4b6fb8f94b1bb37c7a9/ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2", size = 11484012 }, + { url = "https://files.pythonhosted.org/packages/a3/c8/ff8acbd33addc7e797e702cf00bfde352ab469723720c5607b964491d5cf/ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519", size = 12038494 }, + { url = "https://files.pythonhosted.org/packages/73/b1/8d9a2c0efbbabe848b55f877bc10c5001a37ab10aca13c711431673414e5/ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b", size = 12473639 }, + { url = "https://files.pythonhosted.org/packages/cb/44/a673647105b1ba6da9824a928634fe23186ab19f9d526d7bdf278cd27bc3/ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c", size = 9834353 }, + { url = "https://files.pythonhosted.org/packages/c3/01/65cadb59bf8d4fbe33d1a750103e6883d9ef302f60c28b73b773092fbde5/ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4", size = 10821444 }, + { url = "https://files.pythonhosted.org/packages/69/cb/b3fe58a136a27d981911cba2f18e4b29f15010623b79f0f2510fd0d31fd3/ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b", size = 10038168 }, +] + +[[package]] +name = "scipy" +version = "1.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/c6/8eb0654ba0c7d0bb1bf67bf8fbace101a8e4f250f7722371105e8b6f68fc/scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6", size = 59407493 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/53/b204ce5a4433f1864001b9d16f103b9c25f5002a602ae83585d0ea5f9c4a/scipy-1.15.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:c64ded12dcab08afff9e805a67ff4480f5e69993310e093434b10e85dc9d43e1", size = 41414518 }, + { url = "https://files.pythonhosted.org/packages/c7/fc/54ffa7a8847f7f303197a6ba65a66104724beba2e38f328135a78f0dc480/scipy-1.15.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5b190b935e7db569960b48840e5bef71dc513314cc4e79a1b7d14664f57fd4ff", size = 32519265 }, + { url = "https://files.pythonhosted.org/packages/f1/77/a98b8ba03d6f371dc31a38719affd53426d4665729dcffbed4afe296784a/scipy-1.15.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:4b17d4220df99bacb63065c76b0d1126d82bbf00167d1730019d2a30d6ae01ea", size = 24792859 }, + { url = "https://files.pythonhosted.org/packages/a7/78/70bb9f0df7444b18b108580934bfef774822e28fd34a68e5c263c7d2828a/scipy-1.15.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:63b9b6cd0333d0eb1a49de6f834e8aeaefe438df8f6372352084535ad095219e", size = 27886506 }, + { url = "https://files.pythonhosted.org/packages/14/a7/f40f6033e06de4176ddd6cc8c3ae9f10a226c3bca5d6b4ab883bc9914a14/scipy-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f151e9fb60fbf8e52426132f473221a49362091ce7a5e72f8aa41f8e0da4f25", size = 38375041 }, + { url = "https://files.pythonhosted.org/packages/17/03/390a1c5c61fd76b0fa4b3c5aa3bdd7e60f6c46f712924f1a9df5705ec046/scipy-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e10b1dd56ce92fba3e786007322542361984f8463c6d37f6f25935a5a6ef52", size = 40597556 }, + { url = "https://files.pythonhosted.org/packages/4e/70/fa95b3ae026b97eeca58204a90868802e5155ac71b9d7bdee92b68115dd3/scipy-1.15.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5dff14e75cdbcf07cdaa1c7707db6017d130f0af9ac41f6ce443a93318d6c6e0", size = 42938505 }, + { url = "https://files.pythonhosted.org/packages/d6/07/427859116bdd71847c898180f01802691f203c3e2455a1eb496130ff07c5/scipy-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:f82fcf4e5b377f819542fbc8541f7b5fbcf1c0017d0df0bc22c781bf60abc4d8", size = 43909663 }, + { url = "https://files.pythonhosted.org/packages/8e/2e/7b71312da9c2dabff53e7c9a9d08231bc34d9d8fdabe88a6f1155b44591c/scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2", size = 41424362 }, + { url = "https://files.pythonhosted.org/packages/81/8c/ab85f1aa1cc200c796532a385b6ebf6a81089747adc1da7482a062acc46c/scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0", size = 32535910 }, + { url = "https://files.pythonhosted.org/packages/3b/9c/6f4b787058daa8d8da21ddff881b4320e28de4704a65ec147adb50cb2230/scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf", size = 24809398 }, + { url = "https://files.pythonhosted.org/packages/16/2b/949460a796df75fc7a1ee1becea202cf072edbe325ebe29f6d2029947aa7/scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac", size = 27918045 }, + { url = "https://files.pythonhosted.org/packages/5f/36/67fe249dd7ccfcd2a38b25a640e3af7e59d9169c802478b6035ba91dfd6d/scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df", size = 38332074 }, + { url = "https://files.pythonhosted.org/packages/fc/da/452e1119e6f720df3feb588cce3c42c5e3d628d4bfd4aec097bd30b7de0c/scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7", size = 40588469 }, + { url = "https://files.pythonhosted.org/packages/7f/71/5f94aceeac99a4941478af94fe9f459c6752d497035b6b0761a700f5f9ff/scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a", size = 42965214 }, + { url = "https://files.pythonhosted.org/packages/af/25/caa430865749d504271757cafd24066d596217e83326155993980bc22f97/scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b", size = 43896034 }, +] + +[[package]] +name = "setuptools" +version = "75.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/92/ec/089608b791d210aec4e7f97488e67ab0d33add3efccb83a056cbafe3a2a6/setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6", size = 1343222 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/8a/b9dc7678803429e4a3bc9ba462fa3dd9066824d3c607490235c6a796be5a/setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3", size = 1228782 }, +] + +[[package]] +name = "setuptools-scm" +version = "8.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "setuptools" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/a4/00a9ac1b555294710d4a68d2ce8dfdf39d72aa4d769a7395d05218d88a42/setuptools_scm-8.1.0.tar.gz", hash = "sha256:42dea1b65771cba93b7a515d65a65d8246e560768a66b9106a592c8e7f26c8a7", size = 76465 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/b9/1906bfeb30f2fc13bb39bf7ddb8749784c05faadbd18a21cf141ba37bff2/setuptools_scm-8.1.0-py3-none-any.whl", hash = "sha256:897a3226a6fd4a6eb2f068745e49733261a21f70b1bb28fce0339feb978d9af3", size = 43666 }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, +] + +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 }, +] + +[[package]] +name = "snowballstemmer" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/7b/af302bebf22c749c56c9c3e8ae13190b5b5db37a33d9068652e8f73b7089/snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1", size = 86699 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/dc/c02e01294f7265e63a7315fe086dd1df7dacb9f840a804da846b96d01b96/snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a", size = 93002 }, +] + +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 }, +] + +[[package]] +name = "soupsieve" +version = "2.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/ce/fbaeed4f9fb8b2daa961f90591662df6a86c1abf25c548329a86920aedfb/soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb", size = 101569 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9", size = 36186 }, +] + +[[package]] +name = "sphinx" +version = "8.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alabaster" }, + { name = "babel" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "docutils" }, + { name = "imagesize" }, + { name = "jinja2" }, + { name = "packaging" }, + { name = "pygments" }, + { name = "requests" }, + { name = "snowballstemmer" }, + { name = "sphinxcontrib-applehelp" }, + { name = "sphinxcontrib-devhelp" }, + { name = "sphinxcontrib-htmlhelp" }, + { name = "sphinxcontrib-jsmath" }, + { name = "sphinxcontrib-qthelp" }, + { name = "sphinxcontrib-serializinghtml" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/be0b61178fe2cdcb67e2a92fc9ebb488e3c51c4f74a36a7824c0adf23425/sphinx-8.1.3.tar.gz", hash = "sha256:43c1911eecb0d3e161ad78611bc905d1ad0e523e4ddc202a58a821773dc4c927", size = 8184611 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/60/1ddff83a56d33aaf6f10ec8ce84b4c007d9368b21008876fceda7e7381ef/sphinx-8.1.3-py3-none-any.whl", hash = "sha256:09719015511837b76bf6e03e42eb7595ac8c2e41eeb9c29c5b755c6b677992a2", size = 3487125 }, +] + +[[package]] +name = "sphinx-autodoc-typehints" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/f0/43c6a5ff3e7b08a8c3b32f81b859f1b518ccc31e45f22e2b41ced38be7b9/sphinx_autodoc_typehints-3.0.1.tar.gz", hash = "sha256:b9b40dd15dee54f6f810c924f863f9cf1c54f9f3265c495140ea01be7f44fa55", size = 36282 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/dc/dc46c5c7c566b7ec5e8f860f9c89533bf03c0e6aadc96fb9b337867e4460/sphinx_autodoc_typehints-3.0.1-py3-none-any.whl", hash = "sha256:4b64b676a14b5b79cefb6628a6dc8070e320d4963e8ff640a2f3e9390ae9045a", size = 20245 }, +] + +[[package]] +name = "sphinx-jinja2-compat" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/df/27282da6f8c549f765beca9de1a5fc56f9651ed87711a5cac1e914137753/sphinx_jinja2_compat-0.3.0.tar.gz", hash = "sha256:f3c1590b275f42e7a654e081db5e3e5fb97f515608422bde94015ddf795dfe7c", size = 4998 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/42/2fd09d672eaaa937d6893d8b747d07943f97a6e5e30653aee6ebd339b704/sphinx_jinja2_compat-0.3.0-py3-none-any.whl", hash = "sha256:b1e4006d8e1ea31013fa9946d1b075b0c8d2a42c6e3425e63542c1e9f8be9084", size = 7883 }, +] + +[[package]] +name = "sphinx-prompt" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "docutils" }, + { name = "idna" }, + { name = "pygments" }, + { name = "sphinx" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/fe/ac4e24f35b5148b31ac717ae7dcc7a2f7ec56eb729e22c7252ed8ad2d9a5/sphinx_prompt-1.9.0.tar.gz", hash = "sha256:471b3c6d466dce780a9b167d9541865fd4e9a80ed46e31b06a52a0529ae995a1", size = 5340 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/98/e90ca466e0ede452d3e5a8d92b8fb68db6de269856e019ed9cab69440522/sphinx_prompt-1.9.0-py3-none-any.whl", hash = "sha256:fd731446c03f043d1ff6df9f22414495b23067c67011cc21658ea8d36b3575fc", size = 7311 }, +] + +[[package]] +name = "sphinx-rtd-theme" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "sphinx" }, + { name = "sphinxcontrib-jquery" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/44/c97faec644d29a5ceddd3020ae2edffa69e7d00054a8c7a6021e82f20335/sphinx_rtd_theme-3.0.2.tar.gz", hash = "sha256:b7457bc25dda723b20b086a670b9953c859eab60a2a03ee8eb2bb23e176e5f85", size = 7620463 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/77/46e3bac77b82b4df5bb5b61f2de98637724f246b4966cfc34bc5895d852a/sphinx_rtd_theme-3.0.2-py2.py3-none-any.whl", hash = "sha256:422ccc750c3a3a311de4ae327e82affdaf59eb695ba4936538552f3b00f4ee13", size = 7655561 }, +] + +[[package]] +name = "sphinx-tabs" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "pygments" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/32/ab475e252dc2b704e82a91141fa404cdd8901a5cf34958fd22afacebfccd/sphinx-tabs-3.4.5.tar.gz", hash = "sha256:ba9d0c1e3e37aaadd4b5678449eb08176770e0fc227e769b6ce747df3ceea531", size = 16070 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/9f/4ac7dbb9f23a2ff5a10903a4f9e9f43e0ff051f63a313e989c962526e305/sphinx_tabs-3.4.5-py3-none-any.whl", hash = "sha256:92cc9473e2ecf1828ca3f6617d0efc0aa8acb06b08c56ba29d1413f2f0f6cf09", size = 9904 }, +] + +[[package]] +name = "sphinx-toolbox" +version = "3.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apeye" }, + { name = "autodocsumm" }, + { name = "beautifulsoup4" }, + { name = "cachecontrol", extra = ["filecache"] }, + { name = "dict2css" }, + { name = "docutils" }, + { name = "domdf-python-tools" }, + { name = "filelock" }, + { name = "html5lib" }, + { name = "ruamel-yaml" }, + { name = "sphinx" }, + { name = "sphinx-autodoc-typehints" }, + { name = "sphinx-jinja2-compat" }, + { name = "sphinx-prompt" }, + { name = "sphinx-tabs" }, + { name = "tabulate" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/80/f837e85c8c216cdeef9b60393e4b00c9092a1e3d734106e0021abbf5930c/sphinx_toolbox-3.8.1.tar.gz", hash = "sha256:a4b39a6ea24fc8f10e24f052199bda17837a0bf4c54163a56f521552395f5e1a", size = 111977 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/d6/2a28ee4cbc158ae65afb2cfcb6895ef54d972ce1e167f8a63c135b14b080/sphinx_toolbox-3.8.1-py3-none-any.whl", hash = "sha256:53d8e77dd79e807d9ef18590c4b2960a5aa3c147415054b04c31a91afed8b88b", size = 194621 }, +] + +[[package]] +name = "sphinxcontrib-applehelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/6e/b837e84a1a704953c62ef8776d45c3e8d759876b4a84fe14eba2859106fe/sphinxcontrib_applehelp-2.0.0.tar.gz", hash = "sha256:2f29ef331735ce958efa4734873f084941970894c6090408b079c61b2e1c06d1", size = 20053 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/85/9ebeae2f76e9e77b952f4b274c27238156eae7979c5421fba91a28f4970d/sphinxcontrib_applehelp-2.0.0-py3-none-any.whl", hash = "sha256:4cd3f0ec4ac5dd9c17ec65e9ab272c9b867ea77425228e68ecf08d6b28ddbdb5", size = 119300 }, +] + +[[package]] +name = "sphinxcontrib-devhelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/d2/5beee64d3e4e747f316bae86b55943f51e82bb86ecd325883ef65741e7da/sphinxcontrib_devhelp-2.0.0.tar.gz", hash = "sha256:411f5d96d445d1d73bb5d52133377b4248ec79db5c793ce7dbe59e074b4dd1ad", size = 12967 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/7a/987e583882f985fe4d7323774889ec58049171828b58c2217e7f79cdf44e/sphinxcontrib_devhelp-2.0.0-py3-none-any.whl", hash = "sha256:aefb8b83854e4b0998877524d1029fd3e6879210422ee3780459e28a1f03a8a2", size = 82530 }, +] + +[[package]] +name = "sphinxcontrib-htmlhelp" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/93/983afd9aa001e5201eab16b5a444ed5b9b0a7a010541e0ddfbbfd0b2470c/sphinxcontrib_htmlhelp-2.1.0.tar.gz", hash = "sha256:c9e2916ace8aad64cc13a0d233ee22317f2b9025b9cf3295249fa985cc7082e9", size = 22617 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/7b/18a8c0bcec9182c05a0b3ec2a776bba4ead82750a55ff798e8d406dae604/sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = "sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8", size = 98705 }, +] + +[[package]] +name = "sphinxcontrib-jquery" +version = "4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/f3/aa67467e051df70a6330fe7770894b3e4f09436dea6881ae0b4f3d87cad8/sphinxcontrib-jquery-4.1.tar.gz", hash = "sha256:1620739f04e36a2c779f1a131a2dfd49b2fd07351bf1968ced074365933abc7a", size = 122331 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/85/749bd22d1a68db7291c89e2ebca53f4306c3f205853cf31e9de279034c3c/sphinxcontrib_jquery-4.1-py2.py3-none-any.whl", hash = "sha256:f936030d7d0147dd026a4f2b5a57343d233f1fc7b363f68b3d4f1cb0993878ae", size = 121104 }, +] + +[[package]] +name = "sphinxcontrib-jsmath" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/e8/9ed3830aeed71f17c026a07a5097edcf44b692850ef215b161b8ad875729/sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", size = 5787 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071 }, +] + +[[package]] +name = "sphinxcontrib-qthelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/bc/9104308fc285eb3e0b31b67688235db556cd5b0ef31d96f30e45f2e51cae/sphinxcontrib_qthelp-2.0.0.tar.gz", hash = "sha256:4fe7d0ac8fc171045be623aba3e2a8f613f8682731f9153bb2e40ece16b9bbab", size = 17165 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/83/859ecdd180cacc13b1f7e857abf8582a64552ea7a061057a6c716e790fce/sphinxcontrib_qthelp-2.0.0-py3-none-any.whl", hash = "sha256:b18a828cdba941ccd6ee8445dbe72ffa3ef8cbe7505d8cd1fa0d42d3f2d5f3eb", size = 88743 }, +] + +[[package]] +name = "sphinxcontrib-serializinghtml" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/44/6716b257b0aa6bfd51a1b31665d1c205fb12cb5ad56de752dfa15657de2f/sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d", size = 16080 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072 }, +] + +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, +] + +[[package]] +name = "sympy" +version = "1.13.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/8a/5a7fd6284fa8caac23a26c9ddf9c30485a48169344b4bd3b0f02fef1890f/sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9", size = 7533196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73", size = 6189483 }, +] + +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + +[[package]] +name = "tach" +version = "0.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitpython" }, + { name = "networkx" }, + { name = "prompt-toolkit" }, + { name = "pydot" }, + { name = "pyyaml" }, + { name = "rich" }, + { name = "tomli" }, + { name = "tomli-w" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/4b/de2e7ad0a22e63fbed979064381da1290391dd623a3fd80d0728ea72d545/tach-0.23.0.tar.gz", hash = "sha256:ae123491231ab0712417d579b9a3259014d713d72626805ff64552955e43e912", size = 482218 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/87/9aa4142dc31314500af0003f406851a212b589a7e680e78c39751fc26681/tach-0.23.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:aa30db4158e48694154d346def14d3a096672381fa09e3cf09eae190ff9066f0", size = 3240516 }, + { url = "https://files.pythonhosted.org/packages/b3/db/3d856d856a688b024470494785dc8d177e1728904e180aa9394e80d8787e/tach-0.23.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e54365a3101c08a35d51357007e37723cd86c8bf464b73a3b43401edd2053d8", size = 3095903 }, + { url = "https://files.pythonhosted.org/packages/19/c9/1302175f5b350891727356c03bfdbffb884323db3c30cc34b2c7e93c932b/tach-0.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0af6be9328ec907deac141165b43b7db58f055bc20ea46b65b82b10fed72cd3", size = 3373159 }, + { url = "https://files.pythonhosted.org/packages/af/3d/ad4a2f4e2142b789085886a3acbb2f8e1a99068014303c7aa1166350aa38/tach-0.23.0-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b8205440863f61389b29a9baf2e2cd171d87c6931f3d6baf69eda69092440df", size = 3325828 }, + { url = "https://files.pythonhosted.org/packages/ab/87/4114a20e97f9a8652865bdf541d7b3121a731d6539d7f6b7d6bb70a86f46/tach-0.23.0-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b783d0f121c579f761dad7bf6ceeddec8f901e3778ed29a2db57c1c17804577", size = 3627127 }, + { url = "https://files.pythonhosted.org/packages/b5/cd/88b4f103eea5d2a3b0696265131f43f07e5bf9b1b81ccc0471512121ceae/tach-0.23.0-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:625403b59430eee9b5c2c05dff9575c8623ea88bcf58728e55b843fdbf04031d", size = 3623389 }, + { url = "https://files.pythonhosted.org/packages/12/77/3be44b77ad3ab8a6f05c245e399ff1e9f48df6be5e706c34b0863eaa4bdc/tach-0.23.0-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c9671d4be806f9aa6a714a38ac26b455704ac01019555f2441445335e749fb5", size = 3884923 }, + { url = "https://files.pythonhosted.org/packages/d7/8b/d7f9c9a1cb6a0f6745a1c4cdb824bc1abbac2a4f9fa30e57de37b7a223b9/tach-0.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caede4e23800d412c83b96288c7f03845971f6ea10dcfff40a026d294db1996f", size = 3483408 }, + { url = "https://files.pythonhosted.org/packages/48/8e/930460944b5cddeff297de774981ce8ffd1e80c59ea5f0616ade89a6871b/tach-0.23.0-cp37-abi3-win32.whl", hash = "sha256:828a59f7e2effdac3802025177b1a83e53b27ee54b00ef6305a0e36cec448e55", size = 2725999 }, + { url = "https://files.pythonhosted.org/packages/ea/01/4e4c9b551fa9ffd0db74e14966c393928aefa59019b6d5bd8a9a645ee714/tach-0.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:5dc03ef01d1a2e9d39fa238c271e9a4f8d9db2459212425ceb05b8ed0547000f", size = 2930346 }, +] + +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 }, + { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 }, + { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 }, + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, +] + +[[package]] +name = "tomli-w" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/75/241269d1da26b624c0d5e110e8149093c759b7a286138f4efd61a60e75fe/tomli_w-1.2.0.tar.gz", hash = "sha256:2dd14fac5a47c27be9cd4c976af5a12d87fb1f0b4512f81d69cce3b35ae25021", size = 7184 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/18/c86eb8e0202e32dd3df50d43d7ff9854f8e0603945ff398974c1d91ac1ef/tomli_w-1.2.0-py3-none-any.whl", hash = "sha256:188306098d013b691fcadc011abd66727d3c414c571bb01b1a174ba8c983cf90", size = 6675 }, +] + +[[package]] +name = "tomlkit" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/09/a439bec5888f00a54b8b9f05fa94d7f901d6735ef4e55dcec9bc37b5d8fa/tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79", size = 192885 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 }, +] + +[[package]] +name = "toolz" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 }, +] + +[[package]] +name = "tornado" +version = "6.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/59/45/a0daf161f7d6f36c3ea5fc0c2de619746cc3dd4c76402e9db545bd920f63/tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b", size = 501135 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/7e/71f604d8cea1b58f82ba3590290b66da1e72d840aeb37e0d5f7291bd30db/tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1", size = 436299 }, + { url = "https://files.pythonhosted.org/packages/96/44/87543a3b99016d0bf54fdaab30d24bf0af2e848f1d13d34a3a5380aabe16/tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803", size = 434253 }, + { url = "https://files.pythonhosted.org/packages/cb/fb/fdf679b4ce51bcb7210801ef4f11fdac96e9885daa402861751353beea6e/tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec", size = 437602 }, + { url = "https://files.pythonhosted.org/packages/4f/3b/e31aeffffc22b475a64dbeb273026a21b5b566f74dee48742817626c47dc/tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946", size = 436972 }, + { url = "https://files.pythonhosted.org/packages/22/55/b78a464de78051a30599ceb6983b01d8f732e6f69bf37b4ed07f642ac0fc/tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf", size = 437173 }, + { url = "https://files.pythonhosted.org/packages/79/5e/be4fb0d1684eb822c9a62fb18a3e44a06188f78aa466b2ad991d2ee31104/tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634", size = 437892 }, + { url = "https://files.pythonhosted.org/packages/f5/33/4f91fdd94ea36e1d796147003b490fe60a0215ac5737b6f9c65e160d4fe0/tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73", size = 437334 }, + { url = "https://files.pythonhosted.org/packages/2b/ae/c1b22d4524b0e10da2f29a176fb2890386f7bd1f63aacf186444873a88a0/tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c", size = 437261 }, + { url = "https://files.pythonhosted.org/packages/b5/25/36dbd49ab6d179bcfc4c6c093a51795a4f3bed380543a8242ac3517a1751/tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482", size = 438463 }, + { url = "https://files.pythonhosted.org/packages/61/cc/58b1adeb1bb46228442081e746fcdbc4540905c87e8add7c277540934edb/tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38", size = 438907 }, +] + +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, +] + +[[package]] +name = "types-decorator" +version = "5.1.8.20250121" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/e6/88de14bb1d1073495b9d9459f90fbb78fe93d89beefcf0af94b871993a56/types_decorator-5.1.8.20250121.tar.gz", hash = "sha256:1b89bb1c481a1d3399e28f1aa3459366b76dde951490992ae8475ba91287cd04", size = 8496 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/0e/59b9637fa66fbe419886b17d59b90e5e4256325c01f94f81dcc44fbeda53/types_decorator-5.1.8.20250121-py3-none-any.whl", hash = "sha256:6bfd5f4464f444a1ee0aea92705ed8466d74c0ddd7ade4bbd003c235db51d21a", size = 8078 }, +] + +[[package]] +name = "types-docutils" +version = "0.21.0.20241128" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dd/df/64e7ab01a4fc5ce46895dc94e31cffc8b8087c8d91ee54c45ac2d8d82445/types_docutils-0.21.0.20241128.tar.gz", hash = "sha256:4dd059805b83ac6ec5a223699195c4e9eeb0446a4f7f2aeff1759a4a7cc17473", size = 26739 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/b6/10ba95739f2cbb9c5bd2f6568148d62b468afe01a94c633e8892a2936d8a/types_docutils-0.21.0.20241128-py3-none-any.whl", hash = "sha256:e0409204009639e9b0bf4521eeabe58b5e574ce9c0db08421c2ac26c32be0039", size = 34677 }, +] + +[[package]] +name = "types-pytz" +version = "2024.2.0.20241221" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/26/516311b02b5a215e721155fb65db8a965d061372e388d6125ebce8d674b0/types_pytz-2024.2.0.20241221.tar.gz", hash = "sha256:06d7cde9613e9f7504766a0554a270c369434b50e00975b3a4a0f6eed0f2c1a9", size = 10213 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/db/c92ca6920cccd9c2998b013601542e2ac5e59bc805bcff94c94ad254b7df/types_pytz-2024.2.0.20241221-py3-none-any.whl", hash = "sha256:8fc03195329c43637ed4f593663df721fef919b60a969066e22606edf0b53ad5", size = 10008 }, +] + +[[package]] +name = "types-pyyaml" +version = "6.0.12.20241230" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/f9/4d566925bcf9396136c0a2e5dc7e230ff08d86fa011a69888dd184469d80/types_pyyaml-6.0.12.20241230.tar.gz", hash = "sha256:7f07622dbd34bb9c8b264fe860a17e0efcad00d50b5f27e93984909d9363498c", size = 17078 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/c1/48474fbead512b70ccdb4f81ba5eb4a58f69d100ba19f17c92c0c4f50ae6/types_PyYAML-6.0.12.20241230-py3-none-any.whl", hash = "sha256:fa4d32565219b68e6dee5f67534c722e53c00d1cfc09c435ef04d7353e1e96e6", size = 20029 }, +] + +[[package]] +name = "types-tabulate" +version = "0.9.0.20241207" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/43/16030404a327e4ff8c692f2273854019ed36718667b2993609dc37d14dd4/types_tabulate-0.9.0.20241207.tar.gz", hash = "sha256:ac1ac174750c0a385dfd248edc6279fa328aaf4ea317915ab879a2ec47833230", size = 8195 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/86/a9ebfd509cbe74471106dffed320e208c72537f9aeb0a55eaa6b1b5e4d17/types_tabulate-0.9.0.20241207-py3-none-any.whl", hash = "sha256:b8dad1343c2a8ba5861c5441370c3e35908edd234ff036d4298708a1d4cf8a85", size = 8307 }, +] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, +] + +[[package]] +name = "urllib3" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, +] + +[[package]] +name = "virtualenv" +version = "20.29.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 }, +] + +[[package]] +name = "wcmatch" +version = "10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bracex" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/ab/b3a52228538ccb983653c446c1656eddf1d5303b9cb8b9aef6a91299f862/wcmatch-10.0.tar.gz", hash = "sha256:e72f0de09bba6a04e0de70937b0cf06e55f36f37b3deb422dfaf854b867b840a", size = 115578 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/df/4ee467ab39cc1de4b852c212c1ed3becfec2e486a51ac1ce0091f85f38d7/wcmatch-10.0-py3-none-any.whl", hash = "sha256:0dd927072d03c0a6527a20d2e6ad5ba8d0380e60870c383bc533b71744df7b7a", size = 39347 }, +] + +[[package]] +name = "wcwidth" +version = "0.2.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, +] + +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774 }, +] + +[[package]] +name = "wheel" +version = "0.45.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494 }, +] + +[[package]] +name = "xxhash" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/3e/ca49932bade8b3308e74df951c36cbc84c8230c9b8715bae1e0014831aa7/xxhash-3.0.0.tar.gz", hash = "sha256:30b2d97aaf11fb122023f6b44ebb97c6955e9e00d7461a96415ca030b5ceb9c7", size = 74279 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/fe/41444c518df82da46bc7125c9daa4159e6cfc2b682ccc73493b0485b8a70/xxhash-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:219cba13991fd73cf21a5efdafa5056f0ae0b8f79e5e0112967e3058daf73eea", size = 34110 }, + { url = "https://files.pythonhosted.org/packages/6f/83/0afffed636656f65f78e35da174c9bdd86367f9d4da23a87fc9d1b933bbe/xxhash-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3fcbb846af15eff100c412ae54f4974ff277c92eacd41f1ec7803a64fd07fa0c", size = 30664 }, + { url = "https://files.pythonhosted.org/packages/8c/b1/cde24bf3c9d4d6bbe02e9e82604dbd40ab21c9799b0fdb66a4fe2046e96d/xxhash-3.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f475fa817ff7955fc118fc1ca29a6e691d329b7ff43f486af36c22dbdcff1db", size = 241825 }, + { url = "https://files.pythonhosted.org/packages/70/fd/7ebfe1549551c87875b64cf9c925e3cf8be53e475d29aed933643f6dd8aa/xxhash-3.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9200a90f02ff6fd5fb63dea107842da71d8626d99b768fd31be44f3002c60bbe", size = 206492 }, + { url = "https://files.pythonhosted.org/packages/d2/6f/eafbb4ec3baf499423f2de3a5f3b6c5898f3bf4a8714e100d5dfb911fbad/xxhash-3.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a1403e4f551c9ef7bcef09af55f1adb169f13e4de253db0887928e5129f87af1", size = 286394 }, + { url = "https://files.pythonhosted.org/packages/64/05/504e1a7accc8f115ebfba96104c2f4a4aea3fb415bd664a6a1cc8915671e/xxhash-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa7f6ca53170189a2268c83af0980e6c10aae69e6a5efa7ca989f89fff9f8c02", size = 211550 }, + { url = "https://files.pythonhosted.org/packages/f8/b9/b6558ba62479dbdd18f894842f6ec01bbbf94aa8a26340f889c1af550fa8/xxhash-3.0.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b63fbeb6d9c93d50ae0dc2b8a8b7f52f2de19e40fe9edc86637bfa5743b8ba2", size = 219718 }, + { url = "https://files.pythonhosted.org/packages/19/7a/270f9c47d9748b7d43ec2ce0ee1d50c189ccf21e7ba6adc39e4045fcd450/xxhash-3.0.0-cp310-cp310-win32.whl", hash = "sha256:31f25efd10b6f1f6d5c34cd231986d8aae9a42e042daa90b783917f170807869", size = 30157 }, + { url = "https://files.pythonhosted.org/packages/67/54/f98d6eccb96da4fc51f4397123828c593c6f2731ede141f2318d1aab8a6b/xxhash-3.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:807e88ed56e0fb347cb57d5bf44851f9878360fed700f2f63e622ef4eede87a5", size = 29918 }, +] From 5752f2b807e254fe55ee60273b49e07e7f8b5061 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 30 Jan 2025 06:42:44 +0100 Subject: [PATCH 123/178] feat[next][dace]: keep transients on the output of a mapped nested SDFG (#1828) Small change to the lowering from GTIR, that keeps a transient buffer on the output of a nested SDFG, before the write memlet through the `MapExit` node. Removal of the transient, which was the baseline behavior, prevents map fusion in the optimization workflow. --- .../next/program_processors/runners/dace/gtir_dataflow.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index a34828afcb..59d1a0087a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -233,8 +233,12 @@ def connect( ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, (dace.nodes.Tasklet, dace.nodes.NestedSDFG)): + if isinstance(last_node, dace.nodes.Tasklet): # the last transient node can be deleted + # Note that it could also be applied when `last_node` is a NestedSDFG, + # but an exception would be when the inner write to global data is a + # WCR memlet, because that prevents fusion of the outer map. This case + # happens for the reduce with skip values, which uses a map with WCR. last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn self.state.remove_node(self.result.dc_node) else: From 1ccc7c90a9ed5b726bb4d10c9106bb2d19346d2b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Jan 2025 17:20:48 +0100 Subject: [PATCH 124/178] docs: update detailed nox examples (#1838) ## Description Fixed the detailed examples for working with nox (listing all sessions, running a specific session) in the contribution guidelines. ## Requirements - [x] All fixes and/or new features come with corresponding tests. Tested locally by copy/pasting the commands. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- CONTRIBUTING.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e0ef75d31e..28134a61b9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -146,11 +146,11 @@ Check `pytest` documentation (`pytest --help`) for all the options to select and We recommended you to use `nox` for running the test suite in different environments. `nox` runs the package installation script in properly isolated environments to run tests in a reproducible way. A simple way to start with `nox` would be: ```bash -# List all the available task environments -nox list +# List all available sessions +nox --list -# Run a specific task environment -nox -e cartesian-py38-internal-cpu +# Run a specific session +nox -s "test_cartesian-3.10(internal, cpu)" ``` Check `nox` documentation (`nox --help`) for the complete reference. From 050d3b36ac176196f6758459f65efdac9ba084eb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 30 Jan 2025 17:27:45 +0100 Subject: [PATCH 125/178] refactor[cartesian]: Minor cleanup in backends (#1833) ## Description I was reading a lot of code around DaCe/gt-codegen when debugging the new DaCe/gt4py bridge. This PR combines three cleanup commits: - Always get stencil_ir from builder in GTBaseBackends. I've found no usage of `stencil_ir` being anything else than `self.build.gtir` if it was explicitly passed as an argument at all. There's thus no need to pass around `self.build.gtir` as long as we stay in the same class hierarchy. - Avoid unnecessary indenting in generated code. Generated code is optionally formatted, but even if not, we can make sure the code doesn't look too ugly. - Avoid double formatting of source code (if gt4py/dace is configured to do so). No need for formatting intermediate code parts because it's formatted anyway at the end. ## Requirements - [x] All fixes and/or new features come with corresponding tests. Updated test accordingly. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/backend/cuda_backend.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 64 +++++++++---------- src/gt4py/cartesian/backend/gtc_common.py | 39 +++++------ src/gt4py/cartesian/backend/gtcpp_backend.py | 2 +- .../backend_tests/test_backend_api.py | 6 +- 5 files changed, 50 insertions(+), 63 deletions(-) diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index afa749e3f1..9646383c0f 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -141,7 +141,7 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin): GT_BACKEND_T = "gpu" def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) + return self.make_extension(uses_cuda=True) def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 35265f0530..5b822a1ab5 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -12,7 +12,6 @@ import os import pathlib import re -import textwrap from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import dace @@ -432,20 +431,20 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: class DaCeComputationCodegen: template = as_mako( - """ - auto ${name}(const std::array& domain) { - return [domain](${",".join(functor_args)}) { - const int __I = domain[0]; - const int __J = domain[1]; - const int __K = domain[2]; - ${name}${state_suffix} dace_handle; - ${backend_specifics} - auto allocator = gt::sid::cached_allocator(&${allocator}); - ${"\\n".join(tmp_allocs)} - __program_${name}(${",".join(["&dace_handle", *dace_args])}); - }; - } - """ + """\ +auto ${name}(const std::array& domain) { + return [domain](${",".join(functor_args)}) { + const int __I = domain[0]; + const int __J = domain[1]; + const int __K = domain[2]; + ${name}${state_suffix} dace_handle; + ${backend_specifics} + auto allocator = gt::sid::cached_allocator(&${allocator}); + ${"\\n".join(tmp_allocs)} + __program_${name}(${",".join(["&dace_handle", *dace_args])}); + }; +} +""" ) def generate_tmp_allocs(self, sdfg): @@ -511,7 +510,7 @@ def _postprocess_dace_code(code_objects, is_gpu, builder): lines = lines[0:i] + cuda_code.split("\n") + lines[i + 1 :] break - def keep_line(line): + def keep_line(line: str) -> bool: line = line.strip() if line == '#include "../../include/hash.h"': return False @@ -521,11 +520,7 @@ def keep_line(line): return False return True - lines = filter(keep_line, lines) - generated_code = "\n".join(lines) - if builder.options.format_source: - generated_code = codegen.format_source("cpp", generated_code, style="LLVM") - return generated_code + return "\n".join(filter(keep_line, lines)) @classmethod def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDFG): @@ -563,17 +558,18 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDF allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) - generated_code = textwrap.dedent( - f"""#include - #include - #include - {"#include " if is_gpu else omp_header} - namespace gt = gridtools; - {computations} - - {interface} - """ - ) + generated_code = f"""\ +#include +#include +#include +{"#include " if is_gpu else omp_header} +namespace gt = gridtools; + +{computations} + +{interface} +""" + if builder.options.format_source: generated_code = codegen.format_source("cpp", generated_code, style="LLVM") @@ -794,7 +790,7 @@ class DaceCPUBackend(BaseDaceBackend): options = BaseGTBackend.GT_BACKEND_OPTS def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=False) + return self.make_extension(uses_cuda=False) @register @@ -815,4 +811,4 @@ class DaceGPUBackend(BaseDaceBackend): options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}} def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) + return self.make_extension(uses_cuda=True) diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index abc4baede1..348e85de92 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -236,19 +236,15 @@ def generate(self) -> Type[StencilObject]: def generate_computation(self) -> Dict[str, Union[str, Dict]]: dir_name = f"{self.builder.options.name}_src" - src_files = self.make_extension_sources(stencil_ir=self.builder.gtir) + src_files = self._make_extension_sources() return {dir_name: src_files["computation"]} - def generate_bindings( - self, language_name: str, *, stencil_ir: Optional[gtir.Stencil] = None - ) -> Dict[str, Union[str, Dict]]: - if not stencil_ir: - stencil_ir = self.builder.gtir - assert stencil_ir is not None + def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: if language_name != "python": return super().generate_bindings(language_name) + dir_name = f"{self.builder.options.name}_src" - src_files = self.make_extension_sources(stencil_ir=stencil_ir) + src_files = self._make_extension_sources() return {dir_name: src_files["bindings"]} @abc.abstractmethod @@ -260,32 +256,26 @@ def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: """ pass - def make_extension( - self, *, stencil_ir: Optional[gtir.Stencil] = None, uses_cuda: bool = False - ) -> Tuple[str, str]: + def make_extension(self, *, uses_cuda: bool = False) -> Tuple[str, str]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() - if not stencil_ir: - stencil_ir = self.builder.gtir - assert stencil_ir is not None - # Generate source gt_pyext_files: Dict[str, Any] gt_pyext_sources: Dict[str, Any] - if not self.builder.options._impl_opts.get("disable-code-generation", False): - gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir) - gt_pyext_sources = { - **gt_pyext_files["computation"], - **gt_pyext_files["bindings"], - } - else: + if self.builder.options._impl_opts.get("disable-code-generation", False): # Pass NOTHING to the self.builder means try to reuse the source code files gt_pyext_files = {} gt_pyext_sources = { key: gt_utils.NOTHING for key in self.PYEXT_GENERATOR_CLASS.TEMPLATE_FILES.keys() } + else: + gt_pyext_files = self._make_extension_sources() + gt_pyext_sources = { + **gt_pyext_files["computation"], + **gt_pyext_files["bindings"], + } if build_info is not None: next_time = time.perf_counter() @@ -317,10 +307,11 @@ def make_extension( return result - def make_extension_sources(self, *, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def _make_extension_sources(self) -> Dict[str, Dict[str, str]]: """Generate the source for the stencil independently from use case.""" if "computation_src" in self.builder.backend_data: return self.builder.backend_data["computation_src"] + class_name = self.pyext_class_name if self.builder.stencil_id else self.builder.options.name module_name = ( self.pyext_module_name @@ -328,7 +319,7 @@ def make_extension_sources(self, *, stencil_ir: gtir.Stencil) -> Dict[str, Dict[ else f"{self.builder.options.name}_pyext" ) gt_pyext_generator = self.PYEXT_GENERATOR_CLASS(class_name, module_name, self) - gt_pyext_sources = gt_pyext_generator(stencil_ir) + gt_pyext_sources = gt_pyext_generator(self.builder.gtir) final_ext = ".cu" if self.languages and self.languages["computation"] == "cuda" else ".cpp" comp_src = gt_pyext_sources["computation"] for key in [k for k in comp_src.keys() if k.endswith(".src")]: diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 8053409195..96f5672ae4 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -129,7 +129,7 @@ class GTBaseBackend(BaseGTBackend, CLIBackendMixin): PYEXT_GENERATOR_CLASS = GTExtGenerator def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda) + return self.make_extension(uses_cuda=uses_cuda) def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) diff --git a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py index c47ad10e94..3fbf586b35 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py @@ -79,7 +79,7 @@ def test_generate_bindings(backend, tmp_path): ) else: # assumption: only gt backends support python bindings for other languages than python - result = builder.backend.generate_bindings("python", stencil_ir=builder.gtir) + result = builder.backend.generate_bindings("python") assert "init_1_src" in result - srcs = result["init_1_src"] - assert "bindings.cpp" in srcs or "bindings.cu" in srcs + sources = result["init_1_src"] + assert "bindings.cpp" in sources or "bindings.cu" in sources From f0c67e62134a466aed1986f0d853ef5530244906 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 30 Jan 2025 22:33:51 +0100 Subject: [PATCH 126/178] feat[next][dace]: support for field origin in lowering to SDFG (#1818) This PR adds support for GT4Py field arguments with non-zero start index, for example: `inp = constructors.empty(common.domain({IDim: (1, 9)}), ...)` which was supported in baseline only for temporary fields, by means of a data structure called `field_offsets`. This data structure is removed for two reasons: 1. the name "offset" is a left-over from previous design based on dace array offset 3. offset has a different meaning in GT4Py We introduce the GT4Py concept of field origin and use it for both temporary fields and program arguments. The field origin corresponds to the start of the field domain range. This PR also changes the symbolic definition of array shape. Before, the array shape was defined as `[data_size_0, data_size_1, ...]`, now the size corresponds to the range extent `stop - start` as `[(data_0_range_1 - data_0_range_0), (data_1_range_1 - data_1_range_0), ...]`. The translation stage of the dace workflow is extended with an option `disable_field_origin_on_program_arguments` to set the field range start symbols to constant value zero. This is needed for the dace orchestration, because the signature of a dace-orchestrated program does not provide the domain origin. --- .../runners/dace/gtir_builtin_translators.py | 177 +++++++++--- .../runners/dace/gtir_dataflow.py | 9 +- .../runners/dace/gtir_python_codegen.py | 11 +- .../runners/dace/gtir_scan_translator.py | 102 +++---- .../runners/dace/gtir_sdfg.py | 268 ++++++++++-------- .../runners/dace/program.py | 6 +- .../runners/dace/sdfg_callable.py | 54 ++-- .../program_processors/runners/dace/utils.py | 38 ++- .../runners/dace/workflow/translation.py | 10 +- tests/next_tests/definitions.py | 1 - .../feature_tests/dace/test_orchestration.py | 36 ++- .../feature_tests/dace/test_program.py | 4 - .../iterator_tests/test_hdiff.py | 1 + .../dace_tests/test_dace_utils.py | 21 ++ .../dace_tests/test_gtir_to_sdfg.py | 159 ++++++----- 15 files changed, 540 insertions(+), 357 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index 0fe776c3ee..6b2a32c063 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -39,31 +39,28 @@ def get_domain_indices( - dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None + dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]] ) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying - an optional offset in each dimension as start index. + an optional origin in each dimension as start index. Args: dims: The field dimensions. - offsets: The range start index in each dimension. + origin: The domain start index in each dimension. If set to `None`, assume all zeros. Returns: A list of indices for field access in dace arrays. As this list is returned as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before being used in memlet subset because ranges are better supported throughout DaCe. """ - index_variables = [dace.symbolic.SymExpr(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims] - if offsets is None: - return dace_subsets.Indices(index_variables) - else: - return dace_subsets.Indices( - [ - index - offset if offset != 0 else index - for index, offset in zip(index_variables, offsets, strict=True) - ] - ) + index_variables = [ + dace.symbolic.pystr_to_symbolic(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims + ] + origin = [0] * len(index_variables) if origin is None else origin + return dace_subsets.Indices( + [index - start_index for index, start_index in zip(index_variables, origin, strict=True)] + ) @dataclasses.dataclass(frozen=True) @@ -78,18 +75,58 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. - offset: List of index offsets, in each dimension, when the dimension range - does not start from zero; assume zero offset, if not set. + origin: Tuple of start indices, in each dimension, for `FieldType` data. + Pass an empty tuple for `ScalarType` data or zero-dimensional fields. """ dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType - offset: Optional[list[dace.symbolic.SymExpr]] + origin: tuple[dace.symbolic.SymbolicType, ...] + + def __post_init__(self) -> None: + """Implements a sanity check on the constructed data type.""" + assert ( + len(self.origin) == 0 + if isinstance(self.gt_type, ts.ScalarType) + else len(self.origin) == len(self.gt_type.dims) + ) + + def map_to_parent_sdfg( + self, + sdfg_builder: gtir_sdfg.SDFGBuilder, + inner_sdfg: dace.SDFG, + outer_sdfg: dace.SDFG, + outer_sdfg_state: dace.SDFGState, + symbol_mapping: dict[str, dace.symbolic.SymbolicType], + ) -> FieldopData: + """ + Make the data descriptor which 'self' refers to, and which is located inside + a NestedSDFG, available in its parent SDFG. - def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: - """Create a copy of this data descriptor with a different access node.""" - assert data_node != self.dc_node - return FieldopData(data_node, self.gt_type, self.offset) + Thus, it turns 'self' into a non-transient array and creates a new data + descriptor inside the parent SDFG, with same shape and strides. + """ + inner_desc = self.dc_node.desc(inner_sdfg) + assert inner_desc.transient + inner_desc.transient = False + + if isinstance(self.gt_type, ts.ScalarType): + outer, outer_desc = sdfg_builder.add_temp_scalar(outer_sdfg, inner_desc.dtype) + outer_origin = [] + else: + outer, outer_desc = sdfg_builder.add_temp_array_like(outer_sdfg, inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + symbol_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + # Same applies to the symbols used as field origin (the domain range start) + outer_origin = [ + gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) for val in self.origin + ] + + outer_node = outer_sdfg_state.add_access(outer) + return FieldopData(outer_node, self.gt_type, tuple(outer_origin)) def get_local_view( self, domain: FieldopDomain @@ -97,18 +134,20 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) + dc_node=self.dc_node, + gt_dtype=self.gt_type, + subset=dace_subsets.Range.from_string("0"), ) if isinstance(self.gt_type, ts.FieldType): domain_dims = [dim for dim, _, _ in domain] - domain_indices = get_domain_indices(domain_dims) + domain_indices = get_domain_indices(domain_dims, origin=None) it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) for dim, index in zip(domain_dims, domain_indices) } - field_domain = [ - (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + field_origin = [ + (dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i]) for i, dim in enumerate(self.gt_type.dims) ] # The property below is ensured by calling `make_field()` to construct `FieldopData`. @@ -116,11 +155,48 @@ def get_local_view( # to `ListType` element type, while the field domain consists of all global dimensions. assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, field_domain, it_indices + self.dc_node, self.gt_type.dtype, field_origin, it_indices ) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") + def get_symbol_mapping( + self, dataname: str, sdfg: dace.SDFG + ) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to create the symbol mapping for array storage in a nested SDFG. + + Args: + dataname: Name of the data container insiode the nested SDFG. + sdfg: The parent SDFG where the `FieldopData` object lives. + + Returns: + Mapping from symbols in nested SDFG to the corresponding symbolic values + in the parent SDFG. This includes the range start and stop symbols (used + to calculate the array shape as range 'stop - start') and the strides. + """ + if isinstance(self.gt_type, ts.ScalarType): + return {} + ndims = len(self.gt_type.dims) + outer_desc = self.dc_node.desc(sdfg) + assert isinstance(outer_desc, dace.data.Array) + # origin and size of the local dimension, in case of a field with `ListType` data, + # are assumed to be compiled-time values (not symbolic), therefore the start and + # stop range symbols of the inner field only extend over the global dimensions + return ( + {gtx_dace_utils.range_start_symbol(dataname, i): (self.origin[i]) for i in range(ndims)} + | { + gtx_dace_utils.range_stop_symbol(dataname, i): ( + self.origin[i] + outer_desc.shape[i] + ) + for i in range(ndims) + } + | { + gtx_dace_utils.field_stride_symbol_name(dataname, i): stride + for i, stride in enumerate(outer_desc.strides) + } + ) + FieldopDomain: TypeAlias = list[ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] @@ -141,6 +217,33 @@ def get_local_view( """Data type used for field indexing.""" +def get_arg_symbol_mapping( + dataname: str, arg: FieldopResult, sdfg: dace.SDFG +) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to build the mapping from inner to outer SDFG of all symbols + used for storage of a field or a tuple of fields. + + Args: + dataname: The storage name inside the nested SDFG. + arg: The argument field in the parent SDFG. + sdfg: The parent SDFG where the argument field lives. + + Returns: + A mapping from inner symbol names to values or symbolic definitions + in the parent SDFG. + """ + if isinstance(arg, FieldopData): + return arg.get_symbol_mapping(dataname, sdfg) + + symbol_mapping: dict[str, dace.symbolic.SymExpr] = {} + for i, elem in enumerate(arg): + dataname_elem = f"{dataname}_{i}" + symbol_mapping |= get_arg_symbol_mapping(dataname_elem, elem, sdfg) + + return symbol_mapping + + def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: """ Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. @@ -239,7 +342,7 @@ def get_field_layout( Returns: A tuple of three lists containing: - the domain dimensions - - the domain offset in each dimension + - the domain origin, that is the start indices in all dimensions - the domain size in each dimension """ domain_dims, domain_lbs, domain_ubs = zip(*domain) @@ -278,9 +381,9 @@ def _create_field_operator_impl( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) # the memory layout of the output field follows the field operator compute domain - domain_dims, domain_offset, domain_shape = get_field_layout(domain) - domain_indices = get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) + field_dims, field_origin, field_shape = get_field_layout(domain) + field_indices = get_domain_indices(field_dims, field_origin) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): if output_edge.result.gt_dtype != output_type.dtype: @@ -288,8 +391,6 @@ def _create_field_operator_impl( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) assert isinstance(dataflow_output_desc, dace.data.Scalar) - field_shape = domain_shape - field_subset = domain_subset else: assert isinstance(output_type.dtype, ts.ListType) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) @@ -301,8 +402,8 @@ def _create_field_operator_impl( assert len(dataflow_output_desc.shape) == 1 # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) assert output_edge.result.gt_dtype.offset_type is not None - field_shape = [*domain_shape, dataflow_output_desc.shape[0]] - field_subset = domain_subset + dace_subsets.Range.from_array(dataflow_output_desc) + field_shape = [*field_shape, dataflow_output_desc.shape[0]] + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) @@ -312,9 +413,7 @@ def _create_field_operator_impl( output_edge.connect(map_exit, field_node, field_subset) return FieldopData( - field_node, - ts.FieldType(domain_dims, output_edge.result.gt_dtype), - offset=(domain_offset if set(domain_offset) != {0} else None), + field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) ) @@ -535,7 +634,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc) outer_node = state.add_access(outer) - return inner_data.make_copy(outer_node) + return FieldopData(outer_node, inner_data.gt_type, inner_data.origin) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -696,7 +795,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type, offset=None) + return FieldopData(data_node, data_type, origin=()) def translate_make_tuple( @@ -818,7 +917,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type, offset=None) + return FieldopData(temp_node, node.type, origin=()) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 59d1a0087a..584ce849e1 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -807,6 +807,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) input_memlets: dict[str, MemletExpr | ValueExpr] = {} + nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None # define scalar or symbol for the condition value inside the nested SDFG if isinstance(condition_value, SymbolExpr): @@ -845,12 +846,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + # all free symbols are mapped to the symbols available in parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + if isinstance(condition_value, SymbolExpr): + nsdfg_symbols_mapping["__cond"] = condition_value.value nsdfg_node = self.state.add_nested_sdfg( nsdfg, self.sdfg, inputs=set(input_memlets.keys()), outputs=outputs, - symbol_mapping=None, # implicitly map all free symbols to the symbols available in parent SDFG + symbol_mapping=nsdfg_symbols_mapping, ) for inner, input_expr in input_memlets.items(): @@ -1504,7 +1509,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 56a67510e7..763c292836 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -74,19 +74,16 @@ } -def builtin_cast(*args: Any) -> str: - val, target_type = args +def builtin_cast(val: str, target_type: str) -> str: assert target_type in builtins.TYPE_BUILTINS return MATH_BUILTINS_MAPPING[target_type].format(val) -def builtin_if(*args: Any) -> str: - cond, true_val, false_val = args +def builtin_if(cond: str, true_val: str, false_val: str) -> str: return f"{true_val} if {cond} else {false_val}" -def builtin_tuple_get(*args: Any) -> str: - index, tuple_name = args +def builtin_tuple_get(index: str, tuple_name: str) -> str: return f"{tuple_name}_{index}" @@ -99,7 +96,7 @@ def make_const_list(arg: str) -> str: return arg -GENERAL_BUILTIN_MAPPING: dict[str, Callable[[Any], str]] = { +GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = { "cast_": builtin_cast, "if_": builtin_if, "make_const_list": make_const_list, diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index ec88cd8f84..791440c37a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -22,7 +22,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, Optional +import itertools +from typing import TYPE_CHECKING, Any, Iterable import dace from dace import subsets as dace_subsets @@ -108,19 +109,19 @@ def _create_scan_field_operator_impl( assert isinstance(dataflow_output_desc, dace.data.Array) # the memory layout of the output field follows the field operator compute domain - domain_dims, domain_offset, domain_shape = gtir_translators.get_field_layout(domain) - domain_indices = gtir_translators.get_domain_indices(domain_dims, domain_offset) - domain_subset = dace_subsets.Range.from_indices(domain_indices) + field_dims, field_origin, field_shape = gtir_translators.get_field_layout(domain) + field_indices = gtir_translators.get_domain_indices(field_dims, field_origin) + field_subset = dace_subsets.Range.from_indices(field_indices) # the vertical dimension used as scan column is computed by the `LoopRegion` # inside the map scope, therefore it is excluded from the map range - scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in domain_dims].index(True) + scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in field_dims].index(True) # the map scope writes the full-shape dimension corresponding to the scan column field_subset = ( - dace_subsets.Range(domain_subset[:scan_dim_index]) + dace_subsets.Range(field_subset[:scan_dim_index]) + dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}") - + dace_subsets.Range(domain_subset[scan_dim_index + 1 :]) + + dace_subsets.Range(field_subset[scan_dim_index + 1 :]) ) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): @@ -130,7 +131,6 @@ def _create_scan_field_operator_impl( f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." ) field_dtype = output_edge.result.gt_dtype - field_shape = domain_shape # the scan field operator computes a column of scalar values assert len(dataflow_output_desc.shape) == 1 else: @@ -146,7 +146,7 @@ def _create_scan_field_operator_impl( assert len(dataflow_output_desc.shape) == 2 # the lines below extend the array with the local dimension added by the field operator assert output_edge.result.gt_dtype.offset_type is not None - field_shape = [*domain_shape, dataflow_output_desc.shape[1]] + field_shape = [*field_shape, dataflow_output_desc.shape[1]] field_subset = field_subset + dace_subsets.Range.from_string( f"0:{dataflow_output_desc.shape[1]}" ) @@ -158,7 +158,7 @@ def _create_scan_field_operator_impl( # the inner and outer strides have to match scan_output_stride = field_desc.strides[scan_dim_index] # also consider the stride of the local dimension, in case the scan field operator computes a list - local_strides = field_desc.strides[len(domain_dims) :] + local_strides = field_desc.strides[len(field_dims) :] assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0) new_inner_strides = [scan_output_stride, *local_strides] dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides) @@ -168,9 +168,7 @@ def _create_scan_field_operator_impl( output_edge.connect(map_exit, field_node, field_subset) return gtir_translators.FieldopData( - field_node, - ts.FieldType(domain_dims, output_edge.result.gt_dtype), - offset=(domain_offset if set(domain_offset) != {0} else None), + field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) ) @@ -271,7 +269,6 @@ def _lower_lambda_to_nested_sdfg( domain: gtir_translators.FieldopDomain, init_data: gtir_translators.FieldopResult, lambda_symbols: dict[str, ts.DataType], - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], scan_forward: bool, scan_carry_symbol: gtir.Sym, ) -> tuple[dace.SDFG, gtir_translators.FieldopResult]: @@ -297,8 +294,6 @@ def _lower_lambda_to_nested_sdfg( init_data: The data produced in the field operator context that is used to initialize the scan carry value. lambda_symbols: List of symbols used as parameters of the stencil expressions. - lambda_field_offsets: Mapping from symbol name to field origin, - `None` if field origin is 0 in all dimensions. scan_forward: When True, the loop should range starting from the origin; when False, traverse towards origin. scan_carry_symbol: The symbol used in the stencil expression to carry the @@ -317,9 +312,7 @@ def _lower_lambda_to_nested_sdfg( # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo) - lambda_translator = sdfg_builder.setup_nested_context( - lambda_node, nsdfg, lambda_symbols, lambda_field_offsets - ) + lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols) # use the vertical dimension in the domain as scan dimension scan_domain = [ @@ -474,7 +467,7 @@ def connect_scan_output( ) output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) - return gtir_translators.FieldopData(output_node, output_type, offset=scan_lower_bound) + return gtir_translators.FieldopData(output_node, output_type, origin=(scan_lower_bound,)) # write the stencil result (value on one vertical level) into a 1D field # with full vertical shape representing one column @@ -603,24 +596,36 @@ def translate_scan( for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True) } + # lower the scan stencil expression in a separate SDFG context + nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( + stencil_expr, + sdfg, + sdfg_builder, + domain, + init_data, + lambda_symbols, + scan_forward, + im.sym(scan_carry, scan_carry_type), + ) + # visit the arguments to be passed to the lambda expression # this must be executed before visiting the lambda expression, in order to populate # the data descriptor with the correct field domain offsets for field arguments lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] - lambda_args_mapping = { - _scan_input_name(scan_carry): init_data, - } | { - str(param.id): arg for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) - } + lambda_args_mapping = [ + (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), + ] + [ + (im.sym(param.id, arg.gt_type), arg) + for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + ] + + lambda_arg_nodes = dict( + itertools.chain( + *[gtir_translators.flatten_tuples(psym.id, arg) for psym, arg in lambda_args_mapping] + ) + ) - # parse the dataflow input and output symbols - lambda_flat_args: dict[str, gtir_translators.FieldopData] = {} - # the field offset is set to `None` when it is zero in all dimensions - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} - for param, outer_arg in lambda_args_mapping.items(): - tuple_fields = gtir_translators.flatten_tuples(param, outer_arg) - lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} - lambda_flat_args |= dict(tuple_fields) + # parse the dataflow output symbols if isinstance(scan_carry_type, ts.TupleType): lambda_flat_outs = { str(sym.id): sym.type @@ -631,46 +636,23 @@ def translate_scan( else: lambda_flat_outs = {_scan_output_name(scan_carry): scan_carry_type} - # lower the scan stencil expression in a separate SDFG context - nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( - stencil_expr, - sdfg, - sdfg_builder, - domain, - init_data, - lambda_symbols, - lambda_field_offsets, - scan_forward, - im.sym(scan_carry, scan_carry_type), - ) - # build the mapping of symbols from nested SDFG to field operator context nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} - for inner_dataname, outer_arg in lambda_flat_args.items(): - inner_desc = nsdfg.data(inner_dataname) - outer_desc = outer_arg.dc_node.desc(sdfg) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*inner_desc.shape, *inner_desc.strides], - [*outer_desc.shape, *outer_desc.strides], - strict=True, - ) - if dace.symbolic.issymbolic(nested_symbol) - } + for psym, arg in lambda_args_mapping: + nsdfg_symbols_mapping |= gtir_translators.get_arg_symbol_mapping(psym.id, arg, sdfg) # the scan nested SDFG is ready: it is instantiated in the field operator context # where the map scope over the horizontal domain lives nsdfg_node = state.add_nested_sdfg( nsdfg, sdfg, - inputs=set(lambda_flat_args.keys()), + inputs=set(lambda_arg_nodes.keys()), outputs=set(lambda_flat_outs.keys()), symbol_mapping=nsdfg_symbols_mapping, ) lambda_input_edges = [] - for input_connector, outer_arg in lambda_flat_args.items(): + for input_connector, outer_arg in lambda_arg_nodes.items(): arg_desc = outer_arg.dc_node.desc(sdfg) input_subset = dace_subsets.Range.from_array(arg_desc) input_edge = gtir_dataflow.MemletInputEdge( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py index b306a59305..a58e8bcf8a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -16,7 +16,6 @@ import abc import dataclasses -import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -34,6 +33,7 @@ from gt4py.next.program_processors.runners.dace import ( gtir_builtin_translators, gtir_sdfg_utils, + transformations as gtx_transformations, utils as gtx_dace_utils, ) from gt4py.next.type_system import type_specifications as ts, type_translation as tt @@ -121,7 +121,9 @@ class SDFGBuilder(DataflowBuilder, Protocol): @abc.abstractmethod def make_field( - self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, ) -> gtir_builtin_translators.FieldopData: """Retrieve the field data descriptor including the domain offset information.""" ... @@ -142,7 +144,6 @@ def setup_nested_context( expr: gtir.Expr, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType], - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: """ Create an SDFG context to translate a nested expression, indipendent @@ -156,7 +157,6 @@ def setup_nested_context( expr: The nested expresson to be lowered. sdfg: The SDFG where to lower the nested expression. global_symbols: Mapping from symbol name to GTIR data type. - field_offsets: Mapping from symbol name to field origin, `None` if field origin is 0 in all dimensions. Returns: A visitor object implementing the `SDFGBuilder` protocol. @@ -201,6 +201,24 @@ def _collect_symbols_in_domain_expressions( ) +def _make_access_index_for_field( + domain: gtir_builtin_translators.FieldopDomain, data: gtir_builtin_translators.FieldopData +) -> dace.subsets.Range: + """Helper method to build a memlet subset of a field over the given domain.""" + # convert domain expression to dictionary to ease access to the dimensions, + # since the access indices have to follow the order of dimensions in field domain + if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0: + assert data.origin is not None + domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain} + return dace.subsets.Range( + (domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1) + for dim, origin in zip(data.gt_type.dims, data.origin, strict=True) + ) + else: + assert len(domain) == 0 + return dace.subsets.Range.from_string("0") + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. @@ -217,10 +235,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType column_axis: Optional[gtx_common.Dimension] - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=dict) - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( - default_factory=dict - ) + global_symbols: dict[str, ts.DataType] map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -232,18 +247,23 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType return self.offset_provider_type[offset] def make_field( - self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, ) -> gtir_builtin_translators.FieldopData: """ - Helper method to build the field data type associated with an access node in the SDFG. + Helper method to build the field data type associated with a data access node. - In case of `ScalarType` data, the descriptor is constructed with `offset=None`. + In case of `ScalarType` data, the `FieldopData` is constructed with `origin=None`. In case of `FieldType` data, the field origin is added to the data descriptor. Besides, if the `FieldType` contains a local dimension, the descriptor is converted to a canonical form where the field domain consists of all global dimensions (the grid axes) and the field data type is `ListType`, with `offset_type` equal to the field local dimension. + TODO(edoapo): consider refactoring this method and moving it to a type module + close to the `FieldopData` type declaration. + Args: data_node: The access node to the SDFG data storage. data_type: The GT4Py data descriptor, which can either come from a field parameter @@ -253,8 +273,7 @@ def make_field( The descriptor associated with the SDFG data storage, filled with field origin. """ if isinstance(data_type, ts.ScalarType): - return gtir_builtin_translators.FieldopData(data_node, data_type, offset=None) - domain_offset = self.field_offsets.get(data_node.data, None) + return gtir_builtin_translators.FieldopData(data_node, data_type, origin=()) local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] if len(local_dims) == 0: # do nothing: the field domain consists of all global dimensions @@ -279,7 +298,11 @@ def make_field( raise NotImplementedError( "Fields with more than one local dimension are not supported." ) - return gtir_builtin_translators.FieldopData(data_node, field_type, domain_offset) + field_origin = tuple( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.range_start_symbol(data_node.data, axis)) + for axis in range(len(field_type.dims)) + ) + return gtir_builtin_translators.FieldopData(data_node, field_type, field_origin) def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -293,11 +316,8 @@ def setup_nested_context( expr: gtir.Expr, sdfg: dace.SDFG, global_symbols: dict[str, ts.DataType], - field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], ) -> SDFGBuilder: - nsdfg_builder = GTIRToSDFG( - self.offset_provider_type, self.column_axis, global_symbols, field_offsets - ) + nsdfg_builder = GTIRToSDFG(self.offset_provider_type, self.column_axis, global_symbols) nsdfg_params = [ gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() ] @@ -321,28 +341,45 @@ def unique_tasklet_name(self, name: str) -> str: def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] - ) -> tuple[list[dace.symbol], list[dace.symbol]]: + ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: """ Parse field dimensions and allocate symbols for array shape and strides. For local dimensions, the size is known at compile-time and therefore the corresponding array shape dimension is set to an integer literal value. + This method is only called for non-transient arrays, which require symbolic + memory layout. The memory layout of transient arrays, used for temporary + fields, is left to the DaCe default (row major, not necessarily the optimal + one) and might be changed during optimization. + Returns: Two lists of symbols, one for the shape and the other for the strides of the array. """ - dc_dtype = gtir_builtin_translators.INDEX_DTYPE neighbor_table_types = gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) - shape = [ - ( - neighbor_table_types[dim.value].max_neighbors - if dim.kind == gtx_common.DimensionKind.LOCAL - else dace.symbol(gtx_dace_utils.field_size_symbol_name(name, i), dc_dtype) - ) - for i, dim in enumerate(dims) - ] + shape = [] + for i, dim in enumerate(dims): + if dim.kind == gtx_common.DimensionKind.LOCAL: + # for local dimension, the size is taken from the associated connectivity type + shape.append(neighbor_table_types[dim.value].max_neighbors) + elif gtx_dace_utils.is_connectivity_identifier(name, self.offset_provider_type): + # we use symbolic size for the global dimension of a connectivity + shape.append( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i)) + ) + else: + # the size of global dimensions for a regular field is the symbolic + # expression of domain range 'stop - start' + shape.append( + dace.symbolic.pystr_to_symbolic( + "{} - {}".format( + gtx_dace_utils.range_stop_symbol(name, i), + gtx_dace_utils.range_start_symbol(name, i), + ) + ) + ) strides = [ - dace.symbol(gtx_dace_utils.field_stride_symbol_name(name, i), dc_dtype) + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_stride_symbol_name(name, i)) for i in range(len(dims)) ] return shape, strides @@ -470,7 +507,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return field.make_copy(temp_node) + return gtir_builtin_translators.FieldopData(temp_node, field.gt_type, field.origin) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -498,7 +535,6 @@ def _add_sdfg_params( sdfg_args += self._add_storage( sdfg, symbolic_arguments, pname, param.type, transient=False ) - self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables for offset, connectivity_type in gtx_dace_utils.filter_connectivity_types( @@ -532,13 +568,6 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: The temporary data is global, therefore available everywhere in the SDFG but not outside. Then, all statements are translated, one after the other. """ - if node.function_definitions: - raise NotImplementedError("Functions expected to be inlined as lambda calls.") - - # Since program field arguments are passed to the SDFG as full-shape arrays, - # there is no offset that needs to be compensated. - assert len(self.field_offsets) == 0 - sdfg = dace.SDFG(node.id) sdfg.debuginfo = gtir_sdfg_utils.debug_info(node) @@ -605,10 +634,8 @@ def visit_SetAt( # in case the statement returns more than one field target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) - # convert domain expression to dictionary to ease access to dimension boundaries - domain = { - dim: (lb, ub) for dim, lb, ub in gtir_builtin_translators.extract_domain(stmt.domain) - } + # visit the domain expression + domain = gtir_builtin_translators.extract_domain(stmt.domain) expr_input_args = { sym_id @@ -626,22 +653,9 @@ def visit_SetAt( target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient - if isinstance(target.gt_type, ts.FieldType): - target_subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims - ) - source_subset = ( - target_subset - if source.offset is None - else ",".join( - f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" - for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) - ) - ) - else: - assert len(domain) == 0 - target_subset = "0" - source_subset = "0" + assert source.gt_type == target.gt_type + source_subset = _make_access_index_for_field(domain, source) + target_subset = _make_access_index_for_field(domain, target) if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -725,15 +739,11 @@ def visit_Lambda( i.e. a lambda parameter with the same name as a symbol in scope, the parameter will shadow the previous symbol during traversal of the lambda expression. """ - lambda_args_mapping = [ - (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) - ] - lambda_arg_nodes = dict( itertools.chain( *[ - gtir_builtin_translators.flatten_tuples(pname, arg) - for pname, arg in lambda_args_mapping + gtir_builtin_translators.flatten_tuples(psym.id, arg) + for psym, arg in zip(node.params, args, strict=True) ] ) ) @@ -743,44 +753,16 @@ def visit_Lambda( sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: gtir_builtin_translators.get_tuple_type(arg) + psym.id: gtir_builtin_translators.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type - for pname, arg in lambda_args_mapping + for psym, arg in zip(node.params, args, strict=True) } - def get_field_domain_offset( - p_name: str, p_type: ts.DataType - ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: - if isinstance(p_type, ts.FieldType): - if p_name in lambda_arg_nodes: - arg = lambda_arg_nodes[p_name] - assert isinstance(arg, gtir_builtin_translators.FieldopData) - return {p_name: arg.offset} - elif field_domain_offset := self.field_offsets.get(p_name, None): - return {p_name: field_domain_offset} - elif isinstance(p_type, ts.TupleType): - tsyms = gtir_sdfg_utils.flatten_tuple_fields(p_name, p_type) - return functools.reduce( - lambda field_offsets, sym: ( - field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] - ), - tsyms, - {}, - ) - return {} - - # populate mapping from field name to domain offset - lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} - for p_name, p_type in lambda_symbols.items(): - lambda_field_offsets |= get_field_domain_offset(p_name, p_type) - # lower let-statement lambda node as a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo) - lambda_translator = self.setup_nested_context( - node.expr, nsdfg, lambda_symbols, lambda_field_offsets - ) + lambda_translator = self.setup_nested_context(node.expr, nsdfg, lambda_symbols) nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( @@ -800,7 +782,6 @@ def get_field_domain_offset( } input_memlets = {} - nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): if nsdfg_datadesc.transient: continue @@ -809,15 +790,6 @@ def get_field_domain_offset( src_node = lambda_arg_nodes[nsdfg_dataname].dc_node dataname = src_node.data datadesc = src_node.desc(sdfg) - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], - [*datadesc.shape, *datadesc.strides], - strict=True, - ) - if dace.symbolic.issymbolic(nested_symbol) - } else: dataname = nsdfg_dataname datadesc = sdfg.arrays[nsdfg_dataname] @@ -855,6 +827,13 @@ def get_field_domain_offset( if output_data.dc_node.desc(nsdfg).transient } + # map free symbols to parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + for sym, arg in zip(node.params, args, strict=True): + nsdfg_symbols_mapping |= gtir_builtin_translators.get_arg_symbol_mapping( + sym.id, arg, sdfg + ) + nsdfg_node = head_state.add_nested_sdfg( nsdfg, parent=sdfg, @@ -888,33 +867,34 @@ def construct_output_for_nested_sdfg( arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. """ inner_desc = inner_data.dc_node.desc(nsdfg) + inner_dataname = inner_data.dc_node.data if inner_desc.transient: # Transient data nodes only exist within the nested SDFG. In order to return some result data, # the corresponding data container inside the nested SDFG has to be changed to non-transient, # that is externally allocated, as required by the SDFG IR. An output edge will write the result # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. - inner_desc.transient = False - outer, outer_desc = self.add_temp_array_like(sdfg, inner_desc) - # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. - dace.symbolic.safe_replace( - nsdfg_symbols_mapping, - lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + outer_data = inner_data.map_to_parent_sdfg( + self, nsdfg, sdfg, head_state, nsdfg_symbols_mapping ) - connector = inner_data.dc_node.data - outer_node = head_state.add_access(outer) head_state.add_edge( - nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) + nsdfg_node, + inner_dataname, + outer_data.dc_node, + None, + sdfg.make_array_memlet(outer_data.dc_node.data), ) - outer_data = inner_data.make_copy(outer_node) - elif inner_data.dc_node.data in lambda_arg_nodes: + elif inner_dataname in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - outer_data = lambda_arg_nodes[inner_data.dc_node.data] + outer_data = lambda_arg_nodes[inner_dataname] else: - outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = inner_data.make_copy(outer_node) + # This must be a symbol captured from the lambda parent scope. + outer_node = head_state.add_access(inner_dataname) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_type, inner_data.origin + ) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. @@ -941,10 +921,50 @@ def visit_SymRef( return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) +def _remove_field_origin_symbols(ir: gtir.Program, sdfg: dace.SDFG) -> None: + """ + Helper function to remove the origin symbols used in program field arguments, + that is only for non-transient data descriptors in the top-level SDFG. + The start symbol of field domain range is set to constant value 0, thus removing + the corresponding free symbol. These values are propagated to all nested SDFGs. + + This function is only used by `build_sdfg_from_gtir()` when the option flag + `disable_field_origin_on_program_arguments` is set to True. + """ + + # collect symbols used as range start for all program arguments + range_start_symbols: dict[str, dace.symbolic.SymExpr] = {} + for p in ir.params: + if isinstance(p.type, ts.TupleType): + psymbols = [ + sym + for sym in gtir_sdfg_utils.flatten_tuple_fields(p.id, p.type) + if isinstance(sym.type, ts.FieldType) + ] + elif isinstance(p.type, ts.FieldType): + psymbols = [p] + else: + psymbols = [] + for psymbol in psymbols: + assert isinstance(psymbol.type, ts.FieldType) + if len(psymbol.type.dims) == 0: + # zero-dimensional field + continue + dataname = str(psymbol.id) + # set all range start symbols to constant value 0 + range_start_symbols |= { + gtx_dace_utils.range_start_symbol(dataname, i): 0 + for i in range(len(psymbol.type.dims)) + } + # we set all range start symbols to 0 in the top-level SDFG and proagate them to nested SDFGs + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, range_start_symbols, validate=True) + + def build_sdfg_from_gtir( ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType, column_axis: Optional[gtx_common.Dimension] = None, + disable_field_origin_on_program_arguments: bool = False, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -956,11 +976,15 @@ def build_sdfg_from_gtir( ir: The GTIR program node to be lowered to SDFG offset_provider_type: The definitions of offset providers used by the program node column_axis: Vertical dimension used for column scan expressions. + disable_field_origin_on_program_arguments: When True, the field range in all dimensions is assumed to start from 0 Returns: An SDFG in the DaCe canonical form (simplified) """ + if ir.function_definitions: + raise NotImplementedError("Functions expected to be inlined as lambda calls.") + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) @@ -970,11 +994,15 @@ def build_sdfg_from_gtir( # Here we find new names for invalid symbols present in the IR. ir = gtir_sdfg_utils.replace_invalid_symbols(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis) + global_symbols = {str(p.id): p.type for p in ir.params if isinstance(p.type, ts.DataType)} + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis, global_symbols) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct dace_sdfg_utils.inline_loop_blocks(sdfg) + if disable_field_origin_on_program_arguments: + _remove_field_origin_symbols(ir, sdfg) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 78016db0a9..a381346a1e 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -85,11 +85,13 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: if not hasattr(self.backend.executor, "step") else self.backend.executor.step, ) # We know which backend we are using, but we don't know if the compile workflow is cached. - # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously + # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with # the other parts of the workaround when possible. sdfg = dace.SDFG.from_json( - compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code + compile_workflow.translation.replace( + disable_itir_transforms=True, disable_field_origin_on_program_arguments=True + )(gtir_stage).source_code ) self.sdfg_closure_cache["arrays"] = sdfg.arrays diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index 7f221a5a41..09720ddf3c 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional import dace import numpy as np @@ -24,32 +24,40 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str) -> Any: +def _convert_arg(arg: Any) -> tuple[Any, Optional[gtx_common.Domain]]: if not isinstance(arg, gtx_common.Field): - return arg + return arg, None if len(arg.domain.dims) == 0: # Pass zero-dimensional fields as scalars. - return arg.as_scalar() - # field domain offsets are not supported - non_zero_offsets = [ - (dim, dim_range) - for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges, strict=True) - if dim_range.start != 0 - ] - if non_zero_offsets: - dim, dim_range = non_zero_offsets[0] - raise RuntimeError( - f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." - ) - return arg.ndarray + return arg.as_scalar(), None + return arg.ndarray, arg.domain def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names - return { - sdfg_param: _convert_arg(arg, sdfg_param) - for sdfg_param, arg in zip(sdfg_params, args, strict=True) - } + sdfg_arguments = {} + range_symbols: dict[str, int] = {} + for sdfg_param, arg in zip(sdfg_params, args, strict=True): + sdfg_arg, domain = _convert_arg(arg) + sdfg_arguments[sdfg_param] = sdfg_arg + if domain: + assert gtx_common.Domain.is_finite(domain) + range_symbols |= { + gtx_dace_utils.range_start_symbol(sdfg_param, i): r.start + for i, r in enumerate(domain.ranges) + } + range_symbols |= { + gtx_dace_utils.range_stop_symbol(sdfg_param, i): r.stop + for i, r in enumerate(domain.ranges) + } + # sanity check in case range symbols are passed as explicit program arguments + for range_symbol, value in range_symbols.items(): + if (sdfg_arg := sdfg_arguments.get(range_symbol, None)) is not None: + if sdfg_arg != value: + raise ValueError( + f"Received program argument {range_symbol} with value {sdfg_arg}, expected {value}." + ) + return sdfg_arguments | range_symbols def _ensure_is_on_device( @@ -150,18 +158,16 @@ def get_sdfg_args( dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_field_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) - dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = _get_stride_args(sdfg.arrays, dace_conn_args) all_args = { **dace_args, **dace_conn_args, - **dace_shapes, **dace_conn_shapes, - **dace_strides, **dace_conn_strides, + **dace_field_strides, } if check_args: diff --git a/src/gt4py/next/program_processors/runners/dace/utils.py b/src/gt4py/next/program_processors/runners/dace/utils.py index cca0c001e7..5fdace73a9 100644 --- a/src/gt4py/next/program_processors/runners/dace/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/utils.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal +from typing import Final, Literal, Mapping, Union import dace @@ -73,6 +73,16 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str: return field_symbol_name(field_name, axis, "stride") +def range_start_symbol(field_name: str, axis: int) -> str: + """Format name of start symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_0" + + +def range_stop_symbol(field_name: str, axis: int) -> str: + """Format name of stop symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_1" + + def is_field_symbol(name: str) -> bool: return FIELD_SYMBOL_RE.match(name) is not None @@ -90,3 +100,29 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } + + +def safe_replace_symbolic( + val: dace.symbolic.SymbolicType, + symbol_mapping: Mapping[ + Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str] + ], +) -> dace.symbolic.SymbolicType: + """ + Replace free symbols in a dace symbolic expression, using `safe_replace()` + in order to avoid clashes in case the new symbol value is also a free symbol + in the original exoression. + + Args: + val: The symbolic expression where to apply the replacement. + symbol_mapping: The mapping table for symbol replacement. + + Returns: + A new symbolic expression as result of symbol replacement. + """ + # The list `x` is needed because `subs()` returns a new object and can not handle + # replacement dicts of the form `{'x': 'y', 'y': 'x'}`. + # The utility `safe_replace()` will call `subs()` twice in case of such dicts. + x = [val] + dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m))) + return x[-1] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 96be93de5e..6e1b3a6f32 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -35,7 +35,8 @@ class DaCeTranslator( ): device_type: core_defs.DeviceType auto_optimize: bool - itir_transforms_off: bool = False + disable_itir_transforms: bool = False + disable_field_origin_on_program_arguments: bool = False def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -50,10 +51,13 @@ def generate_sdfg( auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: - if not self.itir_transforms_off: + if not self.disable_itir_transforms: ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( - ir, common.offset_provider_to_type(offset_provider), column_axis + ir, + common.offset_provider_to_type(offset_provider), + column_axis, + disable_field_origin_on_program_arguments=self.disable_field_origin_on_program_arguments, ) if auto_opt: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 522250cafc..a96d967430 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -152,7 +152,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), (USES_COMPOSITE_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 8fe0634302..3ba376b08f 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -37,9 +37,6 @@ def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() @@ -62,7 +59,9 @@ def sdfg(): tmp_field, out_field ) - sdfg() + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + sdfg() assert np.allclose( gtx.field_utils.asnumpy(out_field)[2:-2, 2:-2], @@ -85,9 +84,6 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - allocator, backend = unstructured_case.allocator, unstructured_case.backend if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): @@ -139,16 +135,18 @@ def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> in # DaCe strides: number of elements to jump return arg.strides[axis] // arg.itemsize - cSDFG( - a, - out, - offset_provider, - rows=3, - cols=2, - connectivity_E2V=e2v, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), - ) + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) e2v_np = e2v.asnumpy() assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) @@ -160,8 +158,8 @@ def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> in allocator=allocator, ) offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): cSDFG( a, out, diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 4edaf9f85f..e5e2f18608 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -77,13 +77,9 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 ) -@pytest.mark.skipif(dace is None, reason="DaCe not found") def test_halo_exchange_helper_attrs(unstructured): local_int = gtx.int - # TODO(edopao): add support for range symbols in field domain and re-enable this test - pytest.skip("Requires support for field domain range.") - @gtx.field_operator(backend=unstructured.backend) def testee_op( a: gtx.Field[[Vertex, KDim], gtx.int], diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index e44e92013f..1726956332 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -61,6 +61,7 @@ def hdiff(inp, coeff, out, x, y): set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) +@pytest.mark.uses_lift @pytest.mark.uses_origin def test_hdiff(hdiff_reference, program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py new file mode 100644 index 0000000000..eec68a6486 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py @@ -0,0 +1,21 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Test utility functions of the dace backend module.""" + +import pytest + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils + + +def test_safe_replace_symbolic(): + assert gtx_dace_utils.safe_replace_symbolic( + dace.symbolic.pystr_to_symbolic("x*x + y"), symbol_mapping={"x": "y", "y": "x"} + ) == dace.symbolic.pystr_to_symbolic("y*y + x") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 7431ad2b4a..8ebb240339 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -13,6 +13,7 @@ """ import functools +from typing import Any, Callable import numpy as np import pytest @@ -52,13 +53,13 @@ SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( - __w_size_0=N, + __w_0_range_1=N, __w_stride_0=1, - __x_size_0=N, + __x_0_range_1=N, __x_stride_0=1, - __y_size_0=N, + __y_0_range_1=N, __y_stride_0=1, - __z_size_0=N, + __z_0_range_1=N, __z_stride_0=1, size=N, ) @@ -69,31 +70,39 @@ def make_mesh_symbols(mesh: MeshDescriptor): ncells=mesh.num_cells, nedges=mesh.num_edges, nvertices=mesh.num_vertices, - __cells_size_0=mesh.num_cells, + __cells_0_range_1=mesh.num_cells, __cells_stride_0=1, - __edges_size_0=mesh.num_edges, + __edges_0_range_1=mesh.num_edges, __edges_stride_0=1, - __vertices_size_0=mesh.num_vertices, + __vertices_0_range_1=mesh.num_vertices, __vertices_stride_0=1, - __connectivity_C2E_size_0=mesh.num_cells, + __connectivity_C2E_0_range_1=mesh.num_cells, __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=mesh.num_cells, + __connectivity_C2V_0_range_1=mesh.num_cells, __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=mesh.num_edges, + __connectivity_E2V_0_range_1=mesh.num_edges, __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=mesh.num_vertices, + __connectivity_V2E_0_range_1=mesh.num_vertices, __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) +def build_dace_sdfg( + ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType +) -> Callable[..., Any]: + return dace_backend.build_sdfg_from_gtir( + ir, offset_provider_type, disable_field_origin_on_program_arguments=True + ) + + def test_gtir_broadcast(): val = np.random.rand() domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) @@ -116,7 +125,7 @@ def test_gtir_broadcast(): a = np.empty(N, dtype=np.float64) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) np.testing.assert_array_equal(a, val) @@ -152,7 +161,7 @@ def test_gtir_cast(): b = a.astype(np.float32) c = np.empty_like(a, dtype=np.bool_) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) np.testing.assert_array_equal(c, True) @@ -180,7 +189,7 @@ def test_gtir_copy_self(): a = np.random.rand(N) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) assert np.allclose(a, ref) @@ -211,7 +220,7 @@ def test_gtir_tuple_swap(): b = np.random.rand(N) ref = (a.copy(), b.copy()) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref[1]) @@ -250,16 +259,16 @@ def test_gtir_tuple_args(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) x_fields = (a, a, b) tuple_symbols = { - "__x_0_size_0": N, + "__x_0_0_range_1": N, "__x_0_stride_0": 1, - "__x_1_0_size_0": N, + "__x_1_0_0_range_1": N, "__x_1_0_stride_0": 1, - "__x_1_1_size_0": N, + "__x_1_1_0_range_1": N, "__x_1_1_stride_0": 1, } @@ -302,7 +311,7 @@ def test_gtir_tuple_expr(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, a * 2 + b) @@ -356,7 +365,7 @@ def test_gtir_tuple_broadcast_scalar(): c = np.random.rand() d = np.empty(N, dtype=type(a)) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) x_fields = (a, b, c) @@ -387,7 +396,7 @@ def test_gtir_zero_dim_fields(): a = np.asarray(np.random.rand()) b = np.empty(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a.item(), b, **FSYMBOLS) assert np.allclose(a, b) @@ -421,16 +430,16 @@ def test_gtir_tuple_return(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) tuple_symbols = { - "__z_0_0_size_0": N, + "__z_0_0_0_range_1": N, "__z_0_0_stride_0": 1, - "__z_0_1_size_0": N, + "__z_0_1_0_range_1": N, "__z_0_1_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, } @@ -464,7 +473,7 @@ def test_gtir_tuple_target(): b = np.empty_like(a) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref + 1) @@ -496,7 +505,7 @@ def test_gtir_update(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) ref = a - 1.0 @@ -530,7 +539,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -559,7 +568,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, (a + a)) @@ -601,7 +610,7 @@ def test_gtir_sum3(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) d = np.empty_like(a) @@ -645,7 +654,7 @@ def test_gtir_cond(): b = np.random.rand(N) c = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) for s1, s2 in [(1, 2), (2, 1)]: d = np.empty_like(a) @@ -687,12 +696,12 @@ def test_gtir_cond_with_tuple_return(): b = np.random.rand(N) c = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, } @@ -735,7 +744,7 @@ def test_gtir_cond_nested(): a = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) for s1 in [False, True]: for s2 in [False, True]: @@ -841,9 +850,9 @@ def test_gtir_cartesian_shift_left(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[OFFSET:] + DELTA, b[:-OFFSET]) @@ -936,9 +945,9 @@ def test_gtir_cartesian_shift_right(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_size_0=N, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[:-OFFSET] + DELTA, b[OFFSET:]) @@ -1075,7 +1084,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1088,17 +1097,17 @@ def test_gtir_connectivity_shift(): connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __ce_field_size_0=SIMPLE_MESH.num_cells, + __ce_field_0_range_1=SIMPLE_MESH.num_cells, __ce_field_size_1=SIMPLE_MESH.num_edges, __ce_field_stride_0=SIMPLE_MESH.num_edges, __ce_field_stride_1=1, - __ev_field_size_0=SIMPLE_MESH.num_edges, + __ev_field_0_range_1=SIMPLE_MESH.num_edges, __ev_field_size_1=SIMPLE_MESH.num_vertices, __ev_field_stride_0=SIMPLE_MESH.num_vertices, __ev_field_stride_1=1, - __c2e_offset_size_0=SIMPLE_MESH.num_cells, + __c2e_offset_0_range_1=SIMPLE_MESH.num_cells, __c2e_offset_stride_0=1, - __e2v_offset_size_0=SIMPLE_MESH.num_edges, + __e2v_offset_0_range_1=SIMPLE_MESH.num_edges, __e2v_offset_stride_0=1, ) assert np.allclose(ce, ref) @@ -1136,7 +1145,7 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) @@ -1158,7 +1167,7 @@ def test_gtir_connectivity_shift_chain(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __edges_out_size_0=SIMPLE_MESH.num_edges, + __edges_out_0_range_1=SIMPLE_MESH.num_edges, __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) @@ -1196,7 +1205,7 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) @@ -1217,7 +1226,7 @@ def test_gtir_neighbors_as_input(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1254,7 +1263,7 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) @@ -1268,7 +1277,7 @@ def test_gtir_neighbors_as_output(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.max_neighbors, __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, @@ -1317,7 +1326,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1377,7 +1386,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1446,7 +1455,7 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, @@ -1454,7 +1463,7 @@ def test_gtir_reduce_dot_product(): v, connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), - __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1503,7 +1512,7 @@ def test_gtir_reduce_with_cond_neighbors(): e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1531,7 +1540,7 @@ def test_gtir_reduce_with_cond_neighbors(): connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), - __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.shape[1], __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, @@ -1618,7 +1627,7 @@ def test_gtir_symbolic_domain(): b = np.random.rand(N) ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, ref) @@ -1666,7 +1675,7 @@ def test_gtir_let_lambda(): b = np.random.rand(N) ref = np.concatenate((b[0:1], a[1 : N - 1] * 8, b[N - 1 : N])) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, ref) @@ -1701,7 +1710,7 @@ def test_gtir_let_lambda_scalar_expression(): c = np.random.rand(N) d = np.empty_like(c) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, c, d, **FSYMBOLS) assert np.allclose(d, (a * a * b * b * c)) @@ -1750,7 +1759,7 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) @@ -1797,7 +1806,7 @@ def test_gtir_let_lambda_with_cond(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) for s in [False, True]: @@ -1835,16 +1844,16 @@ def test_gtir_let_lambda_with_tuple1(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, } @@ -1884,16 +1893,16 @@ def test_gtir_let_lambda_with_tuple2(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) tuple_symbols = { - "__z_0_size_0": N, + "__z_0_0_range_1": N, "__z_0_stride_0": 1, - "__z_1_size_0": N, + "__z_1_0_range_1": N, "__z_1_stride_0": 1, - "__z_2_size_0": N, + "__z_2_0_range_1": N, "__z_2_stride_0": 1, } @@ -1947,14 +1956,14 @@ def test_gtir_if_scalars(): d1 = np.random.randint(0, 1000) d2 = np.random.randint(0, 1000) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) tuple_symbols = { - "__x_0_size_0": N, + "__x_0_0_range_1": N, "__x_0_stride_0": 1, - "__x_1_0_size_0": N, + "__x_1_0_0_range_1": N, "__x_1_0_stride_0": 1, - "__x_1_1_size_0": N, + "__x_1_1_0_range_1": N, "__x_1_1_stride_0": 1, } @@ -1990,7 +1999,7 @@ def test_gtir_if_values(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) @@ -2032,7 +2041,7 @@ def test_gtir_index(): # we need to run domain inference in order to add the domain annex information to the index node. testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) ref = np.concatenate( (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) From c06cac31f043f6da5ef612db57adab4fe3649d73 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 31 Jan 2025 07:58:32 +0100 Subject: [PATCH 127/178] bug[next]: Use same python in CMake as used for execution (#1567) Co-authored-by: Lorenzo Varese <55581163+lorenzovarese@users.noreply.github.com> --- .../next/otf/compilation/build_systems/cmake_lists.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 0533adac81..23c80793c7 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -88,9 +88,15 @@ def visit_FindDependency(self, dep: FindDependency) -> str: # Instead, design this to be extensible (refer to ADR-0016). match dep.name: case "nanobind": + import sys + import nanobind - py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)" + py = f""" + set(Python_EXECUTABLE {sys.executable}) + + find_package(Python COMPONENTS Interpreter Development REQUIRED) + """ nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)" return py + "\n" + nb case "gridtools_cpu" | "gridtools_gpu": From 6f835983edb0dd8c2b19831098acf52ea338c282 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 31 Jan 2025 10:20:33 +0100 Subject: [PATCH 128/178] build: use frozen dependencies when running mypy in pre-commit (#1841) ... otherwise it might update `uv.lock` which will create a pre-commit error because of changed files. See also https://github.com/astral-sh/uv/issues/10845 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4222224cc4..afca7bfa05 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,7 +63,7 @@ repos: hooks: - id: mypy name: mypy static type checker - entry: uv run mypy --no-install-types src/ + entry: uv run --frozen mypy --no-install-types src/ language: system types_or: [python, pyi] pass_filenames: false From b18bbf98e20d70646a693030762bd766571fba9c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 1 Feb 2025 20:41:00 +0100 Subject: [PATCH 129/178] bug[next]: Fix accept args (#1830) `type_info.accept_args` is used in the iterator type inference to check if a node with type `ts.FunctionType` is callable for a given set of arguments [here](https://github.com/GridTools/gt4py/blob/764ef5098078b327d086c4e06db357b45b046669/src/gt4py/next/iterator/type_system/inference.py#L237). This failed when the function type had an iterator argument with `position_dims="unknown"`. This PR promotes the `_is_compatible_type` function from the iterator type inference to a shared function available in all type systems and switches `accept_args` (or more precisely `function_signature_incompatibilities_func`) to use that function instead of `is_concretizable`. Additionally this PR contains a small cleanup of the type deduction / type system tests: - `test_type_deduction.py` was moved from integration to unit tests (as it only contains unit tests). - All type system tests are moved from `test_type_deduction.py` as they are not specific to the frontend type deduction. Co-authored-by: SF-N --- .../next/iterator/type_system/inference.py | 63 +-- src/gt4py/next/type_system/type_info.py | 72 ++- .../ffront_tests/test_type_deduction.py | 414 +++++++++++++++++ .../test_type_system.py} | 419 +----------------- 4 files changed, 505 insertions(+), 463 deletions(-) create mode 100644 tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py rename tests/next_tests/{integration_tests/feature_tests/ffront_tests/test_type_deduction.py => unit_tests/test_type_system.py} (51%) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 901cb103da..fe450625db 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -32,66 +32,11 @@ def _is_representable_as_int(s: int | str) -> bool: return False -def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): - """ - Predicate to determine if two types are compatible. - - This function gracefully handles: - - iterators with unknown positions which are considered compatible to any other positions - of another iterator. - - iterators which are defined everywhere, i.e. empty defined dimensions - Beside that this function simply checks for equality of types. - - >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> IDim = common.Dimension(value="IDim") - >>> type_on_i_of_i_it = it_ts.IteratorType( - ... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type - ... ) - >>> type_on_undefined_of_i_it = it_ts.IteratorType( - ... position_dims="unknown", defined_dims=[IDim], element_type=bool_type - ... ) - >>> _is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it) - True - - >>> JDim = common.Dimension(value="JDim") - >>> type_on_j_of_j_it = it_ts.IteratorType( - ... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type - ... ) - >>> _is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it) - False - """ - is_compatible = True - - if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType): - if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]): - is_compatible &= type_a.position_dims == type_b.position_dims - if type_a.defined_dims and type_b.defined_dims: - is_compatible &= type_a.defined_dims == type_b.defined_dims - is_compatible &= type_a.element_type == type_b.element_type - elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType): - for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): - is_compatible &= _is_compatible_type(el_type_a, el_type_b) - elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): - for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): - is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True - ): - is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True - ): - is_compatible &= _is_compatible_type(arg_a, arg_b) - is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) - else: - is_compatible &= type_info.is_concretizable(type_a, type_b) - - return is_compatible - - def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: if node.type: - assert _is_compatible_type(node.type, type_), "Node already has a type which differs." + assert type_info.is_compatible_type( + node.type, type_ + ), "Node already has a type which differs." node.type = type_ @@ -475,7 +420,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): if node.type and not isinstance(node.type, ts.DeferredType): - assert _is_compatible_type(node.type, result) + assert type_info.is_compatible_type(node.type, result) node.type = result elif isinstance(result, ObservableTypeSynthesizer) or result is None: pass diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 26373c647f..27dd2cf02c 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -26,6 +26,7 @@ from gt4py.eve.utils import XIterable, xiter from gt4py.next import common +from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_specifications as ts @@ -432,6 +433,69 @@ def contains_local_field(type_: ts.TypeSpec) -> bool: ) +# TODO(tehrengruber): This function has specializations on Iterator types, which are not part of +# the general / shared type system. This functionality should be moved to the iterator-only +# type system, but we need some sort of multiple dispatch for that. +# TODO(tehrengruber): Should this have a direction like is_concretizable? +def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool: + """ + Predicate to determine if two types are compatible. + + This function gracefully handles: + - iterators with unknown positions which are considered compatible to any other positions + of another iterator. + - iterators which are defined everywhere, i.e. empty defined dimensions + Beside that this function simply checks for equality of types. + + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> IDim = common.Dimension(value="IDim") + >>> type_on_i_of_i_it = it_ts.IteratorType( + ... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + ... ) + >>> type_on_undefined_of_i_it = it_ts.IteratorType( + ... position_dims="unknown", defined_dims=[IDim], element_type=bool_type + ... ) + >>> is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it) + True + + >>> JDim = common.Dimension(value="JDim") + >>> type_on_j_of_j_it = it_ts.IteratorType( + ... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type + ... ) + >>> is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it) + False + """ + is_compatible = True + + if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType): + if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]): + is_compatible &= type_a.position_dims == type_b.position_dims + if type_a.defined_dims and type_b.defined_dims: + is_compatible &= type_a.defined_dims == type_b.defined_dims + is_compatible &= type_a.element_type == type_b.element_type + elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType): + if len(type_a.types) != len(type_b.types): + return False + for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): + is_compatible &= is_compatible_type(el_type_a, el_type_b) + elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): + for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): + is_compatible &= is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip( + type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True + ): + is_compatible &= is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip( + type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True + ): + is_compatible &= is_compatible_type(arg_a, arg_b) + is_compatible &= is_compatible_type(type_a.returns, type_b.returns) + else: + is_compatible &= is_concretizable(type_a, type_b) + + return is_compatible + + def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ Check if ``symbol_type`` can be concretized to ``to_type``. @@ -725,11 +789,7 @@ def function_signature_incompatibilities_func( for i, (a_arg, b_arg) in enumerate( zip(list(func_type.pos_only_args) + list(func_type.pos_or_kw_args.values()), args) ): - if ( - b_arg is not UNDEFINED_ARG - and a_arg != b_arg - and not is_concretizable(a_arg, to_type=b_arg) - ): + if b_arg is not UNDEFINED_ARG and a_arg != b_arg and not is_compatible_type(a_arg, b_arg): if i < len(func_type.pos_only_args): arg_repr = f"{_number_to_ordinal_number(i + 1)} argument" else: @@ -739,7 +799,7 @@ def function_signature_incompatibilities_func( for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()): if (a_kwarg := func_type.kw_only_args[kwarg]) != ( b_kwarg := kwargs[kwarg] - ) and not is_concretizable(a_kwarg, to_type=b_kwarg): + ) and not is_compatible_type(a_kwarg, b_kwarg): yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got '{kwargs[kwarg]}'." diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py new file mode 100644 index 0000000000..254772fd8a --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -0,0 +1,414 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import Optional, Pattern + +import pytest + +import gt4py.next.ffront.type_specifications +from gt4py.next import ( + Dimension, + DimensionKind, + Field, + FieldOffset, + astype, + broadcast, + common, + errors, + float32, + float64, + int32, + int64, + neighbor_sum, + where, +) +from gt4py.next.ffront.ast_passes import single_static_assign as ssa +from gt4py.next.ffront.experimental import as_offset +from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.type_system import type_info, type_specifications as ts + +TDim = Dimension("TDim") # Meaningless dimension, used for tests. + + +def test_unpack_assign(): + def unpack_explicit_tuple( + a: Field[[TDim], float64], b: Field[[TDim], float64] + ) -> tuple[Field[[TDim], float64], Field[[TDim], float64]]: + tmp_a, tmp_b = (a, b) + return tmp_a, tmp_b + + parsed = FieldOperatorParser.apply_to_function(unpack_explicit_tuple) + + assert parsed.body.annex.symtable[ssa.unique_name("tmp_a", 0)].type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) + ) + assert parsed.body.annex.symtable[ssa.unique_name("tmp_b", 0)].type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) + ) + + +def test_assign_tuple(): + def temp_tuple(a: Field[[TDim], float64], b: Field[[TDim], int64]): + tmp = a, b + return tmp + + parsed = FieldOperatorParser.apply_to_function(temp_tuple) + + assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.TupleType( + types=[ + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None)), + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None)), + ] + ) + + +def test_adding_bool(): + """Expect an error when using arithmetic on bools.""" + + def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): + return a + b + + with pytest.raises( + errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") + ): + _ = FieldOperatorParser.apply_to_function(add_bools) + + +def test_binop_nonmatching_dims(): + """Binary operations can only work when both fields have the same dimensions.""" + X = Dimension("X") + Y = Dimension("Y") + + def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): + return a + b + + with pytest.raises( + errors.DSLError, + match=( + r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." + ), + ): + _ = FieldOperatorParser.apply_to_function(nonmatching) + + +def test_bitopping_float(): + def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): + return a & b + + with pytest.raises( + errors.DSLError, + match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), + ): + _ = FieldOperatorParser.apply_to_function(float_bitop) + + +def test_signing_bool(): + def sign_bool(a: Field[[TDim], bool]): + return -a + + with pytest.raises( + errors.DSLError, + match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", + ): + _ = FieldOperatorParser.apply_to_function(sign_bool) + + +def test_notting_int(): + def not_int(a: Field[[TDim], int64]): + return not a + + with pytest.raises( + errors.DSLError, + match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", + ): + _ = FieldOperatorParser.apply_to_function(not_int) + + +@pytest.fixture +def premap_setup(): + X = Dimension("X") + Y = Dimension("Y") + Y2XDim = Dimension("Y2X", kind=DimensionKind.LOCAL) + Y2X = FieldOffset("Y2X", source=X, target=(Y, Y2XDim)) + return X, Y, Y2XDim, Y2X + + +def test_premap(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: + return bar(Y2X[0]) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + +def test_premap_nbfield(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: + return bar(Y2X) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y, Y2XDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + +def test_premap_reduce(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: + return 2 * neighbor_sum(bar(Y2X), axis=Y2XDim) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) + ) + + +def test_premap_reduce_sparse(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: + return 5 * neighbor_sum(bar, axis=Y2XDim) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) + ) + + +def test_mismatched_literals(): + def mismatched_lit() -> Field[[TDim], "float32"]: + return float32("1.0") + float64("1.0") + + with pytest.raises( + errors.DSLError, + match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), + ): + _ = FieldOperatorParser.apply_to_function(mismatched_lit) + + +def test_broadcast_multi_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + CDim = Dimension("CDim") + + def simple_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (ADim, BDim, CDim)) + + parsed = FieldOperatorParser.apply_to_function(simple_broadcast) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim, BDim, CDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_broadcast_disjoint(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + CDim = Dimension("CDim") + + def disjoint_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (BDim, CDim)) + + with pytest.raises(errors.DSLError, match=r"expected broadcast dimension\(s\) \'.*\' missing"): + _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) + + +def test_broadcast_badtype(): + ADim = Dimension("ADim") + BDim = "BDim" + CDim = Dimension("CDim") + + def badtype_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (BDim, CDim)) + + with pytest.raises( + errors.DSLError, match=r"expected all broadcast dimensions to be of type 'Dimension'." + ): + _ = FieldOperatorParser.apply_to_function(badtype_broadcast) + + +def test_where_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + + def simple_where(a: Field[[ADim], bool], b: Field[[ADim, BDim], float64]): + return where(a, b, 9.0) + + parsed = FieldOperatorParser.apply_to_function(simple_where) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_where_broadcast_dim(): + ADim = Dimension("ADim") + + def simple_where(a: Field[[ADim], bool]): + return where(a, 5.0, 9.0) + + parsed = FieldOperatorParser.apply_to_function(simple_where) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_where_tuple_dim(): + ADim = Dimension("ADim") + + def tuple_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): + return where(a, ((5.0, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) + + parsed = FieldOperatorParser.apply_to_function(tuple_where) + + assert parsed.body.stmts[0].value.type == ts.TupleType( + types=[ + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ] + ) + + +def test_where_bad_dim(): + ADim = Dimension("ADim") + + def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): + return where(a, ((5.0, 9.0), (b, 6.0)), b) + + with pytest.raises(errors.DSLError, match=r"Return arguments need to be of same type"): + _ = FieldOperatorParser.apply_to_function(bad_dim_where) + + +def test_where_mixed_dims(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + + def tuple_where_mix_dims( + a: Field[[ADim], bool], b: Field[[ADim], float64], c: Field[[ADim, BDim], float64] + ): + return where(a, ((c, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) + + parsed = FieldOperatorParser.apply_to_function(tuple_where_mix_dims) + + assert parsed.body.stmts[0].value.type == ts.TupleType( + types=[ + ts.TupleType( + types=[ + ts.FieldType( + dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ] + ) + + +def test_astype_dtype(): + def simple_astype(a: Field[[TDim], float64]): + return astype(a, bool) + + parsed = FieldOperatorParser.apply_to_function(simple_astype) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ) + + +def test_astype_wrong_dtype(): + def simple_astype(a: Field[[TDim], float64]): + # we just use broadcast here, but anything with type function is fine + return astype(a, broadcast) + + with pytest.raises( + errors.DSLError, + match=r"Invalid call to 'astype': second argument must be a scalar type, got.", + ): + _ = FieldOperatorParser.apply_to_function(simple_astype) + + +def test_astype_wrong_value_type(): + def simple_astype(a: Field[[TDim], float64]): + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) + + with pytest.raises(errors.DSLError) as exc_info: + _ = FieldOperatorParser.apply_to_function(simple_astype) + + assert ( + re.search("Expected 1st argument to be of type", exc_info.value.__cause__.args[0]) + is not None + ) + + +def test_mod_floats(): + def modulo_floats(inp: Field[[TDim], float]): + return inp % 3.0 + + with pytest.raises(errors.DSLError, match=r"Type 'float64' can not be used in operator '%'"): + _ = FieldOperatorParser.apply_to_function(modulo_floats) + + +def test_undefined_symbols(): + def return_undefined(): + return undefined_symbol + + with pytest.raises(errors.DSLError, match="Undeclared symbol"): + _ = FieldOperatorParser.apply_to_function(return_undefined) + + +def test_as_offset_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) + + def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): + return a(as_offset(Boff, b)) + + with pytest.raises(errors.DSLError, match=f"not in list of offset field dimensions"): + _ = FieldOperatorParser.apply_to_function(as_offset_dim) + + +def test_as_offset_dtype(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) + + def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): + return a(as_offset(Boff, b)) + + with pytest.raises(errors.DSLError, match=f"expected integer for offset field dtype"): + _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/test_type_system.py similarity index 51% rename from tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py rename to tests/next_tests/unit_tests/test_type_system.py index 5352724827..99758d6f14 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/test_type_system.py @@ -11,28 +11,13 @@ import pytest -import gt4py.next.ffront.type_specifications from gt4py.next import ( Dimension, DimensionKind, - Field, - FieldOffset, - astype, - broadcast, - common, - errors, - float32, - float64, - int32, - int64, - neighbor_sum, - where, ) -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.experimental import as_offset -from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts - +from gt4py.next.ffront import type_specifications as ts_ffront +from gt4py.next.iterator.type_system import type_specifications as ts_it TDim = Dimension("TDim") # Meaningless dimension, used for tests. @@ -107,7 +92,7 @@ def callable_type_info_cases(): unary_tuple_arg_func_type = ts.FunctionType( pos_only_args=[tuple_type], pos_or_kw_args={}, kw_only_args={}, returns=ts.VoidType() ) - fieldop_type = gt4py.next.ffront.type_specifications.FieldOperatorType( + fieldop_type = ts_ffront.FieldOperatorType( definition=ts.FunctionType( pos_only_args=[field_type, float_type], pos_or_kw_args={}, @@ -115,7 +100,7 @@ def callable_type_info_cases(): returns=field_type, ) ) - scanop_type = gt4py.next.ffront.type_specifications.ScanOperatorType( + scanop_type = ts_ffront.ScanOperatorType( axis=KDim, definition=ts.FunctionType( pos_only_args=[], @@ -124,7 +109,7 @@ def callable_type_info_cases(): returns=float_type, ), ) - tuple_scanop_type = gt4py.next.ffront.type_specifications.ScanOperatorType( + tuple_scanop_type = ts_ffront.ScanOperatorType( axis=KDim, definition=ts.FunctionType( pos_only_args=[], @@ -367,6 +352,22 @@ def callable_type_info_cases(): ], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), + ( + ts.FunctionType( + pos_only_args=[ + ts_it.IteratorType( + position_dims="unknown", defined_dims=[], element_type=float_type + ), + ], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.VoidType(), + ), + [ts_it.IteratorType(position_dims=[IDim], defined_dims=[], element_type=float_type)], + {}, + [], + ts.VoidType(), + ), ] @@ -408,381 +409,3 @@ def test_return_type( accepts_args = type_info.accepts_args(func_type, with_args=args, with_kwargs=kwargs) if accepts_args: assert type_info.return_type(func_type, with_args=args, with_kwargs=kwargs) == return_type - - -def test_unpack_assign(): - def unpack_explicit_tuple( - a: Field[[TDim], float64], b: Field[[TDim], float64] - ) -> tuple[Field[[TDim], float64], Field[[TDim], float64]]: - tmp_a, tmp_b = (a, b) - return tmp_a, tmp_b - - parsed = FieldOperatorParser.apply_to_function(unpack_explicit_tuple) - - assert parsed.body.annex.symtable[ssa.unique_name("tmp_a", 0)].type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) - ) - assert parsed.body.annex.symtable[ssa.unique_name("tmp_b", 0)].type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) - ) - - -def test_assign_tuple(): - def temp_tuple(a: Field[[TDim], float64], b: Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - - assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.TupleType( - types=[ - ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None)), - ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None)), - ] - ) - - -def test_adding_bool(): - """Expect an error when using arithmetic on bools.""" - - def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): - return a + b - - with pytest.raises( - errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") - ): - _ = FieldOperatorParser.apply_to_function(add_bools) - - -def test_binop_nonmatching_dims(): - """Binary operations can only work when both fields have the same dimensions.""" - X = Dimension("X") - Y = Dimension("Y") - - def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): - return a + b - - with pytest.raises( - errors.DSLError, - match=( - r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." - ), - ): - _ = FieldOperatorParser.apply_to_function(nonmatching) - - -def test_bitopping_float(): - def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): - return a & b - - with pytest.raises( - errors.DSLError, - match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), - ): - _ = FieldOperatorParser.apply_to_function(float_bitop) - - -def test_signing_bool(): - def sign_bool(a: Field[[TDim], bool]): - return -a - - with pytest.raises( - errors.DSLError, - match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", - ): - _ = FieldOperatorParser.apply_to_function(sign_bool) - - -def test_notting_int(): - def not_int(a: Field[[TDim], int64]): - return not a - - with pytest.raises( - errors.DSLError, - match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", - ): - _ = FieldOperatorParser.apply_to_function(not_int) - - -@pytest.fixture -def premap_setup(): - X = Dimension("X") - Y = Dimension("Y") - Y2XDim = Dimension("Y2X", kind=DimensionKind.LOCAL) - Y2X = FieldOffset("Y2X", source=X, target=(Y, Y2XDim)) - return X, Y, Y2XDim, Y2X - - -def test_premap(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: - return bar(Y2X[0]) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) - ) - - -def test_premap_nbfield(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: - return bar(Y2X) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y, Y2XDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) - ) - - -def test_premap_reduce(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: - return 2 * neighbor_sum(bar(Y2X), axis=Y2XDim) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) - ) - - -def test_premap_reduce_sparse(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: - return 5 * neighbor_sum(bar, axis=Y2XDim) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) - ) - - -def test_mismatched_literals(): - def mismatched_lit() -> Field[[TDim], "float32"]: - return float32("1.0") + float64("1.0") - - with pytest.raises( - errors.DSLError, - match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), - ): - _ = FieldOperatorParser.apply_to_function(mismatched_lit) - - -def test_broadcast_multi_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - CDim = Dimension("CDim") - - def simple_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (ADim, BDim, CDim)) - - parsed = FieldOperatorParser.apply_to_function(simple_broadcast) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim, BDim, CDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def test_broadcast_disjoint(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - CDim = Dimension("CDim") - - def disjoint_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (BDim, CDim)) - - with pytest.raises(errors.DSLError, match=r"expected broadcast dimension\(s\) \'.*\' missing"): - _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) - - -def test_broadcast_badtype(): - ADim = Dimension("ADim") - BDim = "BDim" - CDim = Dimension("CDim") - - def badtype_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (BDim, CDim)) - - with pytest.raises( - errors.DSLError, match=r"expected all broadcast dimensions to be of type 'Dimension'." - ): - _ = FieldOperatorParser.apply_to_function(badtype_broadcast) - - -def test_where_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - - def simple_where(a: Field[[ADim], bool], b: Field[[ADim, BDim], float64]): - return where(a, b, 9.0) - - parsed = FieldOperatorParser.apply_to_function(simple_where) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def test_where_broadcast_dim(): - ADim = Dimension("ADim") - - def simple_where(a: Field[[ADim], bool]): - return where(a, 5.0, 9.0) - - parsed = FieldOperatorParser.apply_to_function(simple_where) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def test_where_tuple_dim(): - ADim = Dimension("ADim") - - def tuple_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): - return where(a, ((5.0, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) - - parsed = FieldOperatorParser.apply_to_function(tuple_where) - - assert parsed.body.stmts[0].value.type == ts.TupleType( - types=[ - ts.TupleType( - types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ts.TupleType( - types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ] - ) - - -def test_where_bad_dim(): - ADim = Dimension("ADim") - - def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): - return where(a, ((5.0, 9.0), (b, 6.0)), b) - - with pytest.raises(errors.DSLError, match=r"Return arguments need to be of same type"): - _ = FieldOperatorParser.apply_to_function(bad_dim_where) - - -def test_where_mixed_dims(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - - def tuple_where_mix_dims( - a: Field[[ADim], bool], b: Field[[ADim], float64], c: Field[[ADim, BDim], float64] - ): - return where(a, ((c, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) - - parsed = FieldOperatorParser.apply_to_function(tuple_where_mix_dims) - - assert parsed.body.stmts[0].value.type == ts.TupleType( - types=[ - ts.TupleType( - types=[ - ts.FieldType( - dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ts.TupleType( - types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ] - ) - - -def test_astype_dtype(): - def simple_astype(a: Field[[TDim], float64]): - return astype(a, bool) - - parsed = FieldOperatorParser.apply_to_function(simple_astype) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) - ) - - -def test_astype_wrong_dtype(): - def simple_astype(a: Field[[TDim], float64]): - # we just use broadcast here, but anything with type function is fine - return astype(a, broadcast) - - with pytest.raises( - errors.DSLError, - match=r"Invalid call to 'astype': second argument must be a scalar type, got.", - ): - _ = FieldOperatorParser.apply_to_function(simple_astype) - - -def test_astype_wrong_value_type(): - def simple_astype(a: Field[[TDim], float64]): - # we just use broadcast here but anything that is not a field, scalar or tuple thereof works - return astype(broadcast, bool) - - with pytest.raises(errors.DSLError) as exc_info: - _ = FieldOperatorParser.apply_to_function(simple_astype) - - assert ( - re.search("Expected 1st argument to be of type", exc_info.value.__cause__.args[0]) - is not None - ) - - -def test_mod_floats(): - def modulo_floats(inp: Field[[TDim], float]): - return inp % 3.0 - - with pytest.raises(errors.DSLError, match=r"Type 'float64' can not be used in operator '%'"): - _ = FieldOperatorParser.apply_to_function(modulo_floats) - - -def test_undefined_symbols(): - def return_undefined(): - return undefined_symbol - - with pytest.raises(errors.DSLError, match="Undeclared symbol"): - _ = FieldOperatorParser.apply_to_function(return_undefined) - - -def test_as_offset_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) - - def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): - return a(as_offset(Boff, b)) - - with pytest.raises(errors.DSLError, match=f"not in list of offset field dimensions"): - _ = FieldOperatorParser.apply_to_function(as_offset_dim) - - -def test_as_offset_dtype(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) - - def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): - return a(as_offset(Boff, b)) - - with pytest.raises(errors.DSLError, match=f"expected integer for offset field dtype"): - _ = FieldOperatorParser.apply_to_function(as_offset_dtype) From ac253b6b10a8adf87313a48b62328debaf39f07f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 3 Feb 2025 10:28:47 +0100 Subject: [PATCH 130/178] fix[next]: reshuffling for fields with non-zero domain start (#1845) --- src/gt4py/next/embedded/nd_array_field.py | 5 ++++- .../unit_tests/embedded_tests/test_nd_array_field.py | 7 ++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e15fb4266a..537482508b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -657,7 +657,10 @@ def _reshuffling_premap( conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity)) # Take data - take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims) + take_indices = tuple( + conn_map[dim].ndarray - data.domain[dim].unit_range.start # shift to 0-based indexing + for dim in data.domain.dims + ) new_buffer = data._ndarray.__getitem__(take_indices) return data.__class__.from_array( diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9dde5bb40a..9bdc6ab5c1 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -364,10 +364,11 @@ def test_reshuffling_premap(): ij_field = common._field( np.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]), - domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))), ) + max_ij_conn = common._connectivity( - np.fromfunction(lambda i, j: np.maximum(i, j), (3, 3), dtype=int), + np.asarray([[1, 2, 3], [2, 2, 3], [3, 3, 3]], dtype=int), domain=common.Domain( dims=ij_field.domain.dims, ranges=ij_field.domain.ranges, @@ -378,7 +379,7 @@ def test_reshuffling_premap(): result = ij_field.premap(max_ij_conn) expected = common._field( np.asarray([[0.0, 4.0, 8.0], [3.0, 4.0, 8.0], [6.0, 7.0, 8.0]]), - domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))), ) assert result.domain == expected.domain From c17b88259190a5eefee67f039150f4e92337fe61 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 4 Feb 2025 14:14:46 +0100 Subject: [PATCH 131/178] bug[next]: Fix for GTIR partial type inference (#1840) Co-authored-by: Edoardo Paone Co-authored-by: Hannes Vogt --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f5aeac7943..19ab3ecdda 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -141,7 +141,9 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: @_register_builtin_type_synthesizer -def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: +def if_( + pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType +) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): return tree_map( collection_type=ts.TupleType, @@ -149,7 +151,9 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType )(functools.partial(if_, pred))(true_branch, false_branch) assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType) - assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + assert isinstance(pred, ts.DeferredType) or ( + isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + ) # TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are # iterators defined on different positions this fails. For the GTFN backend we also don't # want this, but for roundtrip it is totally fine. From 14d18b30813a0cfc9d327fbf41434931542f9b1e Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 5 Feb 2025 16:14:11 +0100 Subject: [PATCH 132/178] build[next]: switch dace version to main branch from git repo (#1835) Change version of dace dependency in gt4py-next to the main branch of the dace git repository. The version used in gt4py-cartesian remains as before, dace v1.0.1 from the v1/maintenance branch. This PR includes some changes that were needed to comply with the new API of latest dace: - update usage of `ConditionalBlock` and `LoopRegion` - update SDFG transformations - workaround to deal with changes to CFG-tree in SDFG pattern matching --------- Co-authored-by: Philip Mueller --- .pre-commit-config.yaml | 2 +- noxfile.py | 6 +- pyproject.toml | 12 +- .../runners/dace/gtir_dataflow.py | 37 ++-- .../runners/dace/gtir_scan_translator.py | 4 + .../runners/dace/gtir_sdfg.py | 4 - .../dace/transformations/auto_optimize.py | 1 + .../runners/dace/transformations/gpu_utils.py | 139 ++++++++------ .../transformations/local_double_buffering.py | 2 +- .../dace/transformations/map_fusion_helper.py | 170 ++++++++---------- .../dace/transformations/map_fusion_serial.py | 4 +- .../runners/dace/transformations/simplify.py | 87 +++++++-- .../runners/dace/transformations/strides.py | 3 + .../runners/dace/transformations/utils.py | 87 +++++++-- .../test_distributed_buffer_relocator.py | 44 ++++- uv.lock | 85 ++++++--- 16 files changed, 448 insertions(+), 239 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index afca7bfa05..173997849a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/astral-sh/uv-pre-commit # uv version. - rev: 0.5.10 + rev: 0.5.25 hooks: - id: uv-lock diff --git a/noxfile.py b/noxfile.py index 0b150c0db7..e119669e92 100644 --- a/noxfile.py +++ b/noxfile.py @@ -61,6 +61,10 @@ "internal": {"extras": [], "markers": ["not requires_dace"]}, "dace": {"extras": ["dace"], "markers": ["requires_dace"]}, } +# Use dace-next for GT4Py-next, to install a different dace version than in cartesian +CodeGenNextTestSettings = CodeGenTestSettings | { + "dace": {"extras": ["dace-next"], "markers": ["requires_dace"]}, +} # -- nox sessions -- @@ -158,7 +162,7 @@ def test_next( ) -> None: """Run selected 'gt4py.next' tests.""" - codegen_settings = CodeGenTestSettings[codegen] + codegen_settings = CodeGenNextTestSettings[codegen] device_settings = DeviceTestSettings[device] groups: list[str] = ["test"] mesh_markers: list[str] = [] diff --git a/pyproject.toml b/pyproject.toml index 91c2ba0323..4a9071e9d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,8 @@ all = ['gt4py[dace,formatting,jax,performance,testing]'] cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] # features -dace = ['dace>=1.0.0,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 +dace = ['dace>=1.0.1,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 +dace-next = ['dace'] # pull dace latest version from the git repository formatting = ['clang-format>=9.0'] jax = ['jax>=0.4.26'] jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]'] @@ -438,6 +439,14 @@ conflicts = [ {extra = 'jax-cuda12'}, {extra = 'rocm4_3'}, {extra = 'rocm5_0'} + ], + [ + {extra = 'dace'}, + {extra = 'dace-next'} + ], + [ + {extra = 'all'}, + {extra = 'dace-next'} ] ] @@ -448,3 +457,4 @@ url = 'https://test.pypi.org/simple/' [tool.uv.sources] atlas4py = {index = "test.pypi"} +dace = {git = "https://github.com/spcl/dace", branch = "main", extra = "dace-next"} diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 584ce849e1..04d362b834 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -113,7 +113,7 @@ class IteratorExpr: field: dace.nodes.AccessNode gt_dtype: ts.ListType | ts.ScalarType - field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]] indices: dict[gtx_common.Dimension, DataExpr] def get_field_type(self) -> ts.FieldType: @@ -767,9 +767,6 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp assert len(node.args) == 3 - # TODO(edopao): enable once supported in next DaCe release - use_conditional_block: Final[bool] = False - # evaluate the if-condition that will write to a boolean scalar node condition_value = self.visit(node.args[0]) assert ( @@ -785,26 +782,18 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) # create states inside the nested SDFG for the if-branches - if use_conditional_block: - if_region = dace.sdfg.state.ConditionalBlock("if") - nsdfg.add_node(if_region) - entry_state = nsdfg.add_state("entry", is_start_block=True) - nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) - - then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) - tstate = then_body.add_state("true_branch", is_start_block=True) - if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body) - - else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) - fstate = else_body.add_state("false_branch", is_start_block=True) - if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) - - else: - entry_state = nsdfg.add_state("entry", is_start_block=True) - tstate = nsdfg.add_state("true_branch") - nsdfg.add_edge(entry_state, tstate, dace.InterstateEdge(condition="__cond")) - fstate = nsdfg.add_state("false_branch") - nsdfg.add_edge(entry_state, fstate, dace.InterstateEdge(condition="not (__cond)")) + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) input_memlets: dict[str, MemletExpr | ValueExpr] = {} nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index 791440c37a..743b4d33e4 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -312,6 +312,10 @@ def _lower_lambda_to_nested_sdfg( # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo) + # We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`. + # This property is used by pattern matching in SDFG transformation framework + # to skip those transformations that do not yet support control flow blocks. + nsdfg.using_explicit_control_flow = True lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols) # use the vertical dimension in the domain as scan dimension diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py index a58e8bcf8a..a4c0194849 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace -from dace.sdfg import utils as dace_sdfg_utils from gt4py import eve from gt4py.eve import concepts @@ -999,9 +998,6 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) - # TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct - dace_sdfg_utils.inline_loop_blocks(sdfg) - if disable_field_origin_on_program_arguments: _remove_field_origin_symbols(ir, sdfg) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 849730db76..8137c60959 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -280,6 +280,7 @@ def gt_auto_optimize( # For compatibility with DaCe (and until we found out why) the GT4Py # auto optimizer will emulate this behaviour. for state in sdfg.states(): + assert isinstance(state, dace.SDFGState) for edge in state.edges(): edge.data.wcr_nonatomic = False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 8bae56cd88..c2ac528647 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -160,62 +160,9 @@ def gt_gpu_transform_non_standard_memlet( correct loop order. - This function should be called after `gt_set_iteration_order()` has run. """ - new_maps: set[dace_nodes.MapEntry] = set() - - # This code is is copied from DaCe's code generator. - for e, state in list(sdfg.all_edges_recursive()): - nsdfg = state.parent - if ( - isinstance(e.src, dace_nodes.AccessNode) - and isinstance(e.dst, dace_nodes.AccessNode) - and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global - and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global - ): - a: dace_nodes.AccessNode = e.src - b: dace_nodes.AccessNode = e.dst - copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( - None, nsdfg, state, e, a, b - ) - dims = len(copy_shape) - if dims == 1: - continue - elif dims == 2: - if src_strides[-1] != 1 or dst_strides[-1] != 1: - try: - is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] - is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] - except (TypeError, ValueError): - is_src_cont = False - is_dst_cont = False - if is_src_cont and is_dst_cont: - continue - else: - continue - elif dims > 2: - if not (src_strides[-1] != 1 or dst_strides[-1] != 1): - continue - - # For identifying the new map, we first store all neighbors of `a`. - old_neighbors_of_a: list[dace_nodes.AccessNode] = [ - edge.dst for edge in state.out_edges(a) - ] - - # Turn unsupported copy to a map - try: - dace_transformation.dataflow.CopyToMap.apply_to( - nsdfg, save=False, annotate=False, a=a, b=b - ) - except ValueError: # If transformation doesn't match, continue normally - continue - - # We find the new map by comparing the new neighborhood of `a` with the old one. - new_nodes: set[dace_nodes.MapEntry] = { - edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a - } - assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) - assert len(new_nodes) == 1 - new_maps.update(new_nodes) + # Expand all non standard memlets and get the new MapEntries. + new_maps: set[dace_nodes.MapEntry] = _gt_expand_non_standard_memlets(sdfg) # If there are no Memlets that are translated to copy-Maps, then we have nothing to do. if len(new_maps) == 0: @@ -283,6 +230,88 @@ def restrict_fusion_to_newly_created_maps( return sdfg +def _gt_expand_non_standard_memlets( + sdfg: dace.SDFG, +) -> set[dace_nodes.MapEntry]: + """Finds all non standard Memlet in the SDFG and expand them. + + The function is used by `gt_gpu_transform_non_standard_memlet()` and performs + the actual expansion of the Memlet, i.e. turning all Memlets that can not be + expressed as a `memcpy()` into a Map, copy kernel. + The function will return the MapEntries of all expanded. + + The function will process the SDFG recursively. + """ + new_maps: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + new_maps.update(_gt_expand_non_standard_memlets_sdfg(nsdfg)) + return new_maps + + +def _gt_expand_non_standard_memlets_sdfg( + sdfg: dace.SDFG, +) -> set[dace_nodes.MapEntry]: + """Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG.""" + new_maps: set[dace_nodes.MapEntry] = set() + # The implementation is based on DaCe's code generator. + for state in sdfg.states(): + for e in state.edges(): + # We are only interested in edges that connects two access nodes of GPU memory. + if not ( + isinstance(e.src, dace_nodes.AccessNode) + and isinstance(e.dst, dace_nodes.AccessNode) + and e.src.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global + and e.dst.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global + ): + continue + + a: dace_nodes.AccessNode = e.src + b: dace_nodes.AccessNode = e.dst + copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( + None, sdfg, state, e, a, b + ) + dims = len(copy_shape) + if dims == 1: + continue + elif dims == 2: + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + else: + continue + elif dims > 2: + if not (src_strides[-1] != 1 or dst_strides[-1] != 1): + continue + + # For identifying the new map, we first store all neighbors of `a`. + old_neighbors_of_a: list[dace_nodes.AccessNode] = [ + edge.dst for edge in state.out_edges(a) + ] + + # Turn unsupported copy to a map + try: + dace_transformation.dataflow.CopyToMap.apply_to( + sdfg, save=False, annotate=False, a=a, b=b + ) + except ValueError: # If transformation doesn't match, continue normally + continue + + # We find the new map by comparing the new neighborhood of `a` with the old one. + new_nodes: set[dace_nodes.MapEntry] = { + edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a + } + assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) + assert len(new_nodes) == 1 + new_maps.update(new_nodes) + return new_maps + + def gt_set_gpu_blocksize( sdfg: dace.SDFG, block_size: Optional[Sequence[int | str] | str], diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py index 02ecbe28e6..5201748e12 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py @@ -38,7 +38,6 @@ def gt_create_local_double_buffering( it is not needed that the whole data is stored, but only the working set of a single thread. """ - processed_maps = 0 for nsdfg in sdfg.all_sdfgs_recursive(): processed_maps += _create_local_double_buffering_non_recursive(nsdfg) @@ -60,6 +59,7 @@ def _create_local_double_buffering_non_recursive( processed_maps = 0 for state in sdfg.states(): + assert isinstance(state, dace.SDFGState) scope_dict = state.scope_dict() for node in state.nodes(): if not isinstance(node, dace_nodes.MapEntry): diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py index eceb07ed82..03e5973c3c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py @@ -73,11 +73,6 @@ class MapFusionHelper(transformation.SingleStateTransformation): # `False` then the fusion will be rejected. _apply_fusion_callback: Optional[FusionCallback] - # Maps SDFGs to the set of data that can not be removed, - # because they transmit data _between states_, such data will be made 'shared'. - # This variable acts as a cache, and is managed by 'is_shared_data()'. - _shared_data: Dict[SDFG, Set[str]] - def __init__( self, only_inner_maps: Optional[bool] = None, @@ -87,7 +82,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - self._shared_data = {} + self._shared_data = {} # type: ignore[var-annotated] self._apply_fusion_callback = None if only_toplevel_maps is not None: self.only_toplevel_maps = bool(only_toplevel_maps) @@ -380,107 +375,94 @@ def rename_map_parameters( def is_shared_data( self, data: nodes.AccessNode, + state: dace.SDFGState, sdfg: dace.SDFG, ) -> bool: - """Tests if `data` is interstate data, an can not be removed. + """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. - Interstate data is used to transmit data between multiple state or by - extension within the state. Thus it must be classified as a shared output. - This function will go through the SDFG to and collect the names of all data - container that should be classified as shared. Note that this is an over - approximation as it does not take the location into account, i.e. "is no longer - used". + Depending on the situation, the function will not perform a scan of the whole SDFG: + 1) If `data` is non transient then the function will return `True`, as non transient data + must be reconstructed always. + 2) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge + it is classified as shared. + 3) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. + 4) The function will perform a scan. - Args: - transient: The transient that should be checked. - sdfg: The SDFG containing the array. + :param data: The transient that should be checked. + :param state: The state in which the fusion is performed. + :param sdfg: The SDFG in which we want to perform the fusing. - Note: - The function computes the this set once for every SDFG and then caches it. - There is no mechanism to detect if the cache must be evicted. However, - as long as no additional data is added, there is no problem. """ - if sdfg not in self._shared_data: - self._compute_shared_data(sdfg) - return data.data in self._shared_data[sdfg] - - def _compute_shared_data( + # If `data` is non transient then return `True` as the intermediate can not be removed. + if not data.desc(sdfg).transient: + return True + + # This means the data is consumed by multiple Maps, through the same AccessNode, in this state + # Note currently multiple incoming edges are not handled, but in the spirit of this function + # we consider such AccessNodes as shared, because we can not remove the intermediate. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + # We have to perform the full scan of the SDFG. + return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg) + + def _scan_sdfg_if_data_is_shared( self, + data: nodes.AccessNode, + state: dace.SDFGState, sdfg: dace.SDFG, - ) -> None: - """Updates the internal set of shared data/interstate data of `self` for `sdfg`. + ) -> bool: + """Scans `sdfg` to determine if `data` is shared. - See the documentation for `self.is_shared_data()` for a description. + Essentially, this function determines if the intermediate AccessNode `data` + can be removed or if it has to be restored as output of the Map. + A data descriptor is classified as shared if any of the following is true: + - `data` is non transient data. + - `data` has at most one incoming and/or outgoing edge. + - There are other AccessNodes beside `data` that refer to the same data. + - The data is accessed on an interstate edge. - Args: - sdfg: The SDFG for which the set of shared data should be computed. + This function should not be called directly. Instead it is called indirectly + by `is_shared_data()` if there is no short cut. + + :param data: The AccessNode that should checked if it is shared. + :param sdfg: The SDFG for which the set of shared data should be computed. """ - # Shared data of this SDFG. - shared_data: Set[str] = set() - - # All global data can not be removed, so it must always be shared. - for data_name, data_desc in sdfg.arrays.items(): - if not data_desc.transient: - shared_data.add(data_name) - elif isinstance(data_desc, dace.data.Scalar): - shared_data.add(data_name) - - # We go through all states and classify the nodes/data: - # - Data is referred to in different states. - # - The access node is a view (both have to survive). - # - Transient sink or source node. - # - The access node has output degree larger than 1 (input degrees larger - # than one, will always be partitioned as shared anyway). - prevously_seen_data: Set[str] = set() - interstate_read_symbols: Set[str] = set() - for state in sdfg.nodes(): - for access_node in state.data_nodes(): - if access_node.data in shared_data: - # The data was already classified to be shared data - pass - - elif access_node.data in prevously_seen_data: - # We have seen this data before, either in this state or in - # a previous one, but we did not classifies it as shared back then - shared_data.add(access_node.data) - - if state.in_degree(access_node) == 0: - # (Transient) sink nodes are used in other states, or simplify - # will get rid of them. - shared_data.add(access_node.data) - - elif ( - state.out_degree(access_node) != 1 - ): # state.out_degree() == 0 or state.out_degree() > 1 - # The access node is either a source node (it is shared in another - # state) or the node has a degree larger than one, so it is used - # in this state somewhere else. - shared_data.add(access_node.data) - - elif self.is_view(node=access_node, sdfg=sdfg): - # To ensure that the write to the view happens, both have to be shared. - viewed_data: str = self.track_view( - view=access_node, state=state, sdfg=sdfg - ).data - shared_data.update([access_node.data, viewed_data]) - prevously_seen_data.update([access_node.data, viewed_data]) - - else: - # The node was not classified as shared data, so we record that - # we saw it. Note that a node that was immediately classified - # as shared node will never be added to this set, but a data - # that was found twice will be inside this list. - prevously_seen_data.add(access_node.data) - - # Now we are collecting all symbols that interstate edges read from. + if not data.desc(sdfg).transient: + return True + + # See description in `is_shared_data()` for more. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + data_name: str = data.data + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode is data: + # We have found the `data` AccessNode, which we must ignore. + continue + if dnode.data == data_name: + # We found a different AccessNode that refers to the same data + # as `data`. Thus `data` is shared. + return True + + # Test if the data is referenced in the interstate edges. for edge in sdfg.edges(): - interstate_read_symbols.update(edge.data.read_symbols()) + if data_name in edge.data.free_symbols: + # The data is used in the inter state edges. So it is shared. + return True - # We also have to keep everything the edges referrers to and is an array. - shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + # Test if they are accessed in a condition of a loop or conditional block. + for cfr in sdfg.all_control_flow_regions(): + if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): + return True - # Update the internal cache - self._shared_data[sdfg] = shared_data + # The `data` is not used anywhere else, thus `data` is not shared. + return False def _compute_multi_write_data( self, @@ -522,7 +504,7 @@ def _compute_multi_write_data( def is_node_reachable_from( self, - graph: Union[dace.SDFG, dace.SDFGState], + graph: dace.SDFGState, begin: nodes.Node, end: nodes.Node, ) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 2cdcc455d4..0ef33cae97 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -621,7 +621,7 @@ def partition_first_outputs( # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). # Note that "removed" here means that it is reconstructed by a new # output of the second map. - if self.is_shared_data(intermediate_node, sdfg): + if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): # The intermediate data is used somewhere else, either in this or another state. shared_outputs.add(out_edge) else: @@ -798,7 +798,7 @@ def handle_intermediate_set( # It will only have the shape `new_inter_shape` which is basically its # output within one Map iteration. # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" # Now generate the intermediate data container. if len(new_inter_shape) == 0: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index e798df4596..f1fa65a716 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -151,6 +151,12 @@ def gt_inline_nested_sdfg( nb_preproccess_total = 0 nb_inlines_total = 0 while True: + # TODO(edopao): we call `reset_cfg_list()` as temporary workaround for a + # dace issue with pattern matching. Any time the SDFG's CFG-tree is modified, + # i.e. a loop is added/removed or something similar, the CFG list needs + # to be updated accordingly. Otherwise, all ID-based accesses are not going + # to work (which is what pattern matching attempts to do). + sdfg.reset_cfg_list() nb_preproccess = sdfg.apply_transformations_repeated( [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], validate=False, @@ -203,12 +209,15 @@ def gt_substitute_compiletime_symbols( repl: Maps the name of the symbol to the value it should be replaced with. validate: Perform validation at the end of the function. validate_all: Perform validation also on intermediate steps. + + Todo: This function needs improvement. """ # We will use the `replace` function of the top SDFG, however, lower levels # are handled using ConstantPropagation. sdfg.replace_dict(repl) + # TODO(phimuell): Get rid of the `ConstantPropagation` const_prop = dace_passes.ConstantPropagation() const_prop.recursive = True const_prop.progress = False @@ -334,9 +343,12 @@ def _is_read_downstream( write_g: dace_nodes.AccessNode = self.node_write_g tmp_node: dace_nodes.AccessNode = self.node_tmp + # TODO(phimuell): Run the `StateReachability` pass in a pipeline and use + # the `_pipeline_results` member to access the data. return gtx_transformations.utils.is_accessed_downstream( start_state=start_state, sdfg=sdfg, + reachable_states=None, data_to_look=data_to_look, nodes_to_ignore={read_g, write_g, tmp_node}, ) @@ -429,21 +441,29 @@ def should_reapply(self, modified: dace_ppl.Modifies) -> bool: def depends_on(self) -> set[type[dace_transformation.Pass]]: return { dace_transformation.passes.StateReachability, - dace_transformation.passes.AccessSets, + dace_transformation.passes.FindAccessStates, } def apply_pass( self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] ) -> Optional[dict[dace.SDFGState, set[str]]]: + # NOTE: We can not use `AccessSets` because this pass operates on + # `ControlFlowBlock`s, which might consists of multiple states. Thus we are + # using `FindAccessStates` which has this `SDFGState` granularity. The downside + # is, however, that we have to determine if the access in that state is a + # write or not, which means we have to find it first. + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + + # For speeding up the `is_accessed_downstream()` calls. reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[ "StateReachability" ][sdfg.cfg_id] - access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]] = pipeline_results[ - "AccessSets" - ][sdfg.cfg_id] + result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set) - to_relocate = self._find_candidates(sdfg, reachable, access_sets) + to_relocate = self._find_candidates(sdfg, reachable, access_states) if len(to_relocate) == 0: return None self._relocate_write_backs(sdfg, to_relocate) @@ -485,7 +505,7 @@ def _find_candidates( self, sdfg: dace.SDFG, reachable: dict[dace.SDFGState, set[dace.SDFGState]], - access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]], + access_states: dict[str, set[dace.SDFGState]], ) -> list[tuple[AccessLocation, list[AccessLocation]]]: """Determines all temporaries that have to be relocated. @@ -515,9 +535,7 @@ def _find_candidates( if len(candidate_dst_nodes) == 0: continue - for temp_storage in state.source_nodes(): - if not isinstance(temp_storage, dace_nodes.AccessNode): - continue + for temp_storage in state.data_nodes(): if not temp_storage.desc(sdfg).transient: continue if state.out_degree(temp_storage) != 1: @@ -548,7 +566,11 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: temp_storage_node, temp_storage_state = temp_storage def_locations: list[AccessLocation] = [] for upstream_state in find_upstream_states(temp_storage_state): - if temp_storage_node.data in access_sets[upstream_state][1]: + if self._is_written_to_in_state( + data=temp_storage_node.data, + state=upstream_state, + access_states=access_states, + ): # NOTE: We do not impose any restriction on `temp_storage`. Thus # It could be that we do read from it (we can never write to it) # in this state or any other state later. @@ -592,11 +614,13 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: if gtx_transformations.utils.is_accessed_downstream( start_state=def_state, sdfg=sdfg, + reachable_states=reachable, data_to_look=wb_node.data, nodes_to_ignore={def_node, wb_node}, ): break - # check if the global data is not used between the definition of + + # Check if the global data is not used between the definition of # `dest_storage` and where its written back. However, we ignore # the state were `temp_storage` is defined. The checks if these # checks are performed by the `_check_read_write_dependency()` @@ -605,9 +629,14 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: global_nodes_in_def_state = { dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name } + + # The `is_accessed_downstream()` function has some odd behaviour + # regarding `states_to_ignore`. Because of the special SDFGs we have + # this should not be an issue. if gtx_transformations.utils.is_accessed_downstream( start_state=def_state, sdfg=sdfg, + reachable_states=reachable, data_to_look=global_data_name, nodes_to_ignore=global_nodes_in_def_state, states_to_ignore={wb_state}, @@ -620,6 +649,36 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: return result + def _is_written_to_in_state( + self, + data: str, + state: dace.SDFGState, + access_states: dict[str, set[dace.SDFGState]], + ) -> bool: + """This function determines if there is a write to data `data` in state `state`. + + Args: + data: Name of the data descriptor that should be tested. + state: The state that should be examined. + access_states: The set of state that writes to a specific data. + """ + assert data in access_states, f"Did not found '{data}' in 'access_states'." + + # According to `access_states` `data` is not accessed inside `state`. + # Therefore there is no write. + if state not in access_states[data]: + return False + + # There is an AccessNode for `data` inside `state`. Now we have to find the + # node and determine if it is a write or not. + for dnode in state.data_nodes(): + if dnode.data != data: + continue + if state.in_degree(dnode) > 0: + return True + + return False + def _check_read_write_dependency( self, sdfg: dace.SDFG, @@ -889,9 +948,11 @@ def apply( # The data is no longer referenced in this state, so we can potentially # remove if graph.out_degree(access_node) == 0: + # TODO(phimuell): Use the pipeline to run `StateReachability` once. if not gtx_transformations.utils.is_accessed_downstream( start_state=graph, sdfg=sdfg, + reachable_states=None, data_to_look=access_node.data, nodes_to_ignore={access_node}, ): @@ -952,6 +1013,7 @@ class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation): Todo: - Implement a real pointwise test. + - Run this inside a pipeline. """ map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) @@ -1015,9 +1077,12 @@ def can_be_applied( # Test if `tmp` is only anywhere else, this is important for removing it. if graph.out_degree(tmp_ac) != 1: return False + # TODO(phimuell): Use the pipeline system to run the `StateReachability` pass + # only once. Taking care of DaCe issue 1911. if gtx_transformations.utils.is_accessed_downstream( start_state=graph, sdfg=sdfg, + reachable_states=None, data_to_look=tmp_ac.data, nodes_to_ignore={tmp_ac}, ): diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index 9af76e5b57..c037535124 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -192,6 +192,8 @@ def gt_propagate_strides_from_access_node( processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. """ + assert isinstance(state, dace.SDFGState) + if processed_nsdfgs is None: # For preventing the case that nested SDFGs are handled multiple time. processed_nsdfgs = set() @@ -631,6 +633,7 @@ def _gt_find_toplevel_data_accesses( not_top_level_data: set[str] = set() for state in sdfg.states(): + assert isinstance(state, dace.SDFGState) scope_dict = state.scope_dict() for dnode in state.data_nodes(): data: str = dnode.data diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 3cc2dadd89..d315f99264 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -13,6 +13,7 @@ import dace from dace import data as dace_data from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils @@ -134,13 +135,14 @@ def is_accessed_downstream( start_state: dace.SDFGState, sdfg: dace.SDFG, data_to_look: str, + reachable_states: Optional[dict[dace.SDFGState, set[dace.SDFGState]]], nodes_to_ignore: Optional[set[dace_nodes.AccessNode]] = None, states_to_ignore: Optional[set[dace.SDFGState]] = None, ) -> bool: """Scans for accesses to the data container `data_to_look`. The function will go through states that are reachable from `start_state` - (included) and test if there is an AccessNode that refers to `data_to_look`. + (included) and test if there is an AccessNode that _reads_ from `data_to_look`. It will return `True` the first time it finds such a node. The function will ignore all nodes that are listed in `nodes_to_ignore`. @@ -151,35 +153,82 @@ def is_accessed_downstream( start_state: The state where the scanning starts. sdfg: The SDFG on which we operate. data_to_look: The data that we want to look for. + reachable_states: Maps an `SDFGState` to all `SDFGState`s that can be reached. + If `None` it will be computed, but this is not recommended. nodes_to_ignore: Ignore these nodes. states_to_ignore: Ignore these states. + + Note: + Currently, the function will not only ignore the states that are listed in + `states_to_ignore`, but all that are reachable from any of these states. + Thus care must be taken when this option is used. Furthermore, this behaviour + is not intended and will change in further versions. + `reachable_states` can be computed by using the `StateReachability` analysis + pass from DaCe. + + Todo: + - Modify the function such that it is no longer necessary to pass the + `reachable_states` argument. + - Fix the behaviour for `states_to_ignore`. """ - seen_states: set[dace.SDFGState] = set() - to_visit: list[dace.SDFGState] = [start_state] + # After DaCe 1 switched to a hierarchical version of the state machine. Thus + # it is no longer possible in a simple way to traverse the SDFG. As a temporary + # solution we use the `StateReachability` pass. However, this has some issues, + # see the note about `states_to_ignore`. + if reachable_states is None: + state_reachability_pass = dace_analysis.StateReachability() + reachable_states = state_reachability_pass.apply_pass(sdfg, None)[sdfg.cfg_id] + else: + # Ensures that the externally generated result was passed properly. + assert all( + isinstance(state, dace.SDFGState) and state.sdfg is sdfg for state in reachable_states + ) + ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set() ign_states: set[dace.SDFGState] = states_to_ignore or set() - while len(to_visit) > 0: - state = to_visit.pop() - seen_states.add(state) - for dnode in state.data_nodes(): + # NOTE: We have to include `start_state`, however, we must also consider the + # data in `reachable_states` as immutable, so we have to do it this way. + # TODO(phimuell): Go back to a trivial scan of the graph. + if start_state not in reachable_states: + # This can mean different things, either there was only one state to begin + # with or `start_state` is the last one. In this case the `states_to_scan` + # set consists only of the `start_state` because we have to process it. + states_to_scan = {start_state} + else: + # Ensure that `start_state` is scanned. + states_to_scan = reachable_states[start_state].union([start_state]) + + # In the first version we explored the state machine and if we encountered a + # state in the ignore set we simply ignored it. This is no longer possible. + # Instead we will remove all states from the `states_to_scan` that are reachable + # from an ignored state. However, this is not the same as if we would explore + # the state machine (as we did before). Consider the following case: + # + # (STATE_1) ------------> (STATE_2) + # | /\ + # V | + # (STATE_3) ------------------+ + # + # Assume that `STATE_1` is the starting state and `STATE_3` is ignored. + # If we would explore the state machine, we would still scan `STATE_2`. + # However, because `STATE_2` is also reachable from `STATE_3` it will now be + # ignored. In most cases this should be fine, but we have to handle it. + states_to_scan.difference_update(ign_states) + for ign_state in ign_states: + states_to_scan.difference_update(reachable_states.get(ign_state, set())) + assert start_state in states_to_scan + + for downstream_state in states_to_scan: + if downstream_state in ign_states: + continue + for dnode in downstream_state.data_nodes(): if dnode.data != data_to_look: continue if dnode in ign_dnodes: continue - if state.out_degree(dnode) != 0: + if downstream_state.out_degree(dnode) != 0: return True # There is a read operation - - # Look for new states, also scan the interstate edges. - for out_edge in sdfg.out_edges(state): - if out_edge.dst in ign_states: - continue - if data_to_look in out_edge.data.read_symbols(): - return True - if out_edge.dst in seen_states: - continue - to_visit.append(out_edge.dst) - return False diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index ae3624ce13..d3aadf8927 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -18,6 +18,7 @@ ) from . import util +import dace def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: @@ -343,6 +344,47 @@ def test_distributed_buffer_non_sink_temporary(): assert wb_state.number_of_nodes() == 4 res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) - sdfg.view() assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t1", "t2"} assert wb_state.number_of_nodes() == 0 + + +def _make_distributed_buffer_conditional_block_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG("distributed_buffer_conditional_block_sdfg") + + for name in ["a", "b", "c", "t"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + sdfg.add_symbol("cond", dace.bool_) + + # create states inside the nested SDFG for the if-branches + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + entry_state = sdfg.add_state("entry", is_start_block=True) + sdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tstate.add_nedge(tstate.add_access("a"), tstate.add_access("t"), dace.Memlet("a[0:10]")) + if_region.add_branch(dace.sdfg.state.CodeBlock("cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + fstate.add_nedge(fstate.add_access("b"), fstate.add_access("t"), dace.Memlet("b[0:10]")) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (cond)"), else_body) + + wb_state = sdfg.add_state_after(if_region) + wb_state.add_nedge(wb_state.add_access("t"), wb_state.add_access("c"), dace.Memlet("t[0:10]")) + sdfg.validate() + return sdfg, wb_state + + +def test_distributed_buffer_conditional_block(): + sdfg, wb_state = _make_distributed_buffer_conditional_block_sdfg() + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t"} diff --git a/uv.lock b/uv.lock index 1d050717af..60f62028bd 100644 --- a/uv.lock +++ b/uv.lock @@ -9,6 +9,12 @@ conflicts = [[ { package = "gt4py", extra = "jax-cuda12" }, { package = "gt4py", extra = "rocm4-3" }, { package = "gt4py", extra = "rocm5-0" }, +], [ + { package = "gt4py", extra = "dace" }, + { package = "gt4py", extra = "dace-next" }, +], [ + { package = "gt4py", extra = "all" }, + { package = "gt4py", extra = "dace-next" }, ]] [[package]] @@ -175,8 +181,8 @@ dependencies = [ { name = "packaging" }, { name = "pathspec" }, { name = "platformdirs" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, - { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449 } wheels = [ @@ -261,8 +267,8 @@ version = "24.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, - { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/64/65/af6d57da2cb32c076319b7489ae0958f746949d407109e3ccf4d115f147c/cattrs-24.1.2.tar.gz", hash = "sha256:8028cfe1ff5382df59dd36474a86e02d817b06eaf8af84555441bac915d2ef85", size = 426462 } wheels = [ @@ -384,7 +390,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -431,7 +437,7 @@ name = "colorlog" version = "6.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d3/7a/359f4d5df2353f26172b3cc39ea32daa39af8de522205f512f458923e677/colorlog-6.9.0.tar.gz", hash = "sha256:bfba54a1b93b94f54e1f4fe48395725a3d92fd2a4af702f6bd70946bdc0c6ac2", size = 16624 } wheels = [ @@ -515,7 +521,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] [[package]] @@ -668,10 +674,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b9/fe9da37090b6444c65f848a83e390f87d8cb43d6a4df46de1556ad7e5ceb/cytoolz-1.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3237e56211e03b13df47435b2369f5df281e02b04ad80a948ebd199b7bc10a47", size = 343358 }, ] +[[package]] +name = "dace" +version = "1.0.0" +source = { git = "https://github.com/spcl/dace?branch=main#118c1312961dc1146f43d5b15cde4b97e067d9cb" } +resolution-markers = [ + "python_full_version < '3.11'", + "python_full_version >= '3.11'", +] +dependencies = [ + { name = "aenum" }, + { name = "astunparse" }, + { name = "dill" }, + { name = "fparser" }, + { name = "networkx" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "ply" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyyaml" }, + { name = "sympy" }, +] + [[package]] name = "dace" version = "1.0.1" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", + "python_full_version >= '3.11'", +] dependencies = [ { name = "aenum" }, { name = "astunparse" }, @@ -681,7 +713,7 @@ dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "ply" }, - { name = "pyreadline", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "pyyaml" }, { name = "sympy" }, ] @@ -998,7 +1030,6 @@ wheels = [ [[package]] name = "gt4py" -version = "1.0.4" source = { editable = "." } dependencies = [ { name = "attrs" }, @@ -1033,7 +1064,7 @@ dependencies = [ [package.optional-dependencies] all = [ { name = "clang-format" }, - { name = "dace" }, + { name = "dace", version = "1.0.1", source = { registry = "https://pypi.org/simple" } }, { name = "hypothesis" }, { name = "jax" }, { name = "pytest" }, @@ -1046,7 +1077,10 @@ cuda12 = [ { name = "cupy-cuda12x" }, ] dace = [ - { name = "dace" }, + { name = "dace", version = "1.0.1", source = { registry = "https://pypi.org/simple" } }, +] +dace-next = [ + { name = "dace", version = "1.0.0", source = { git = "https://github.com/spcl/dace?branch=main#118c1312961dc1146f43d5b15cde4b97e067d9cb" } }, ] formatting = [ { name = "clang-format" }, @@ -1170,7 +1204,8 @@ requires-dist = [ { name = "cupy-rocm-4-3", marker = "extra == 'rocm4-3'", specifier = ">=13.3.0" }, { name = "cupy-rocm-5-0", marker = "extra == 'rocm5-0'", specifier = ">=13.3.0" }, { name = "cytoolz", specifier = ">=0.12.1" }, - { name = "dace", marker = "extra == 'dace'", specifier = ">=1.0.0,<1.1.0" }, + { name = "dace", marker = "extra == 'dace'", specifier = ">=1.0.1,<1.1.0" }, + { name = "dace", marker = "extra == 'dace-next'", git = "https://github.com/spcl/dace?branch=main" }, { name = "deepdiff", specifier = ">=5.6.0" }, { name = "devtools", specifier = ">=0.6" }, { name = "diskcache", specifier = ">=5.6.3" }, @@ -1359,7 +1394,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "appnope", marker = "sys_platform == 'darwin' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1383,12 +1418,12 @@ name = "ipython" version = "8.31.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "decorator" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "jedi" }, { name = "matplotlib-inline" }, - { name = "pexpect", marker = "(sys_platform != 'emscripten' and sys_platform != 'win32') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pexpect", marker = "(sys_platform != 'emscripten' and sys_platform != 'win32') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "prompt-toolkit" }, { name = "pygments" }, { name = "stack-data" }, @@ -1538,7 +1573,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "(platform_python_implementation != 'PyPy' and sys_platform == 'win32') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pywin32", marker = "(platform_python_implementation != 'PyPy' and sys_platform == 'win32') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -1556,7 +1591,7 @@ dependencies = [ { name = "nbformat" }, { name = "packaging" }, { name = "pyyaml" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/10/e7/58d6fd374e1065d2bccefd07953d2f1f911d8de03fd7dc33dd5a25ac659c/jupytext-1.16.6.tar.gz", hash = "sha256:dbd03f9263c34b737003f388fc069e9030834fb7136879c4c32c32473557baa0", size = 3726029 } wheels = [ @@ -1821,7 +1856,7 @@ version = "1.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mypy-extensions" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/eb/2c92d8ea1e684440f54fa49ac5d9a5f19967b7b472a281f419e69a8d228e/mypy-1.14.1.tar.gz", hash = "sha256:7ec88144fe9b510e8475ec2f5f251992690fcf89ccb4500b214b4226abcd32d6", size = 3216051 } @@ -1995,7 +2030,7 @@ dependencies = [ { name = "argcomplete" }, { name = "colorlog" }, { name = "packaging" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "virtualenv" }, ] sdist = { url = "https://files.pythonhosted.org/packages/08/93/4df547afcd56e0b2bbaa99bc2637deb218a01802ed62d80f763189be802c/nox-2024.10.9.tar.gz", hash = "sha256:7aa9dc8d1c27e9f45ab046ffd1c3b2c4f7c91755304769df231308849ebded95", size = 4003197 } @@ -2575,7 +2610,7 @@ name = "pyzmq" version = "26.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fd/05/bed626b9f7bb2322cdbbf7b4bd8f54b1b617b0d2ab2d3547d6e39428a48e/pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f", size = 271975 } wheels = [ @@ -2658,7 +2693,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, - { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 } wheels = [ @@ -2730,7 +2765,7 @@ name = "ruamel-yaml" version = "0.18.10" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ruamel-yaml-clib", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "ruamel-yaml-clib", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ea/46/f44d8be06b85bc7c4d8c95d658be2b68f27711f279bf9dd0612a5e4794f5/ruamel.yaml-0.18.10.tar.gz", hash = "sha256:20c86ab29ac2153f80a428e1254a8adf686d3383df04490514ca3b79a362db58", size = 143447 } wheels = [ @@ -2890,7 +2925,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -2904,7 +2939,7 @@ dependencies = [ { name = "sphinxcontrib-jsmath" }, { name = "sphinxcontrib-qthelp" }, { name = "sphinxcontrib-serializinghtml" }, - { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/be0b61178fe2cdcb67e2a92fc9ebb488e3c51c4f74a36a7824c0adf23425/sphinx-8.1.3.tar.gz", hash = "sha256:43c1911eecb0d3e161ad78611bc905d1ad0e523e4ddc202a58a821773dc4c927", size = 8184611 } wheels = [ From d916ae57bc93f73e38dd3a1b6cdcf02f23259921 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 6 Feb 2025 12:26:13 +0100 Subject: [PATCH 133/178] feat[next]: Improve fieldop fusion (#1764) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add support for inlining into scans. - Fuse `make_tuple(as_fieldop(...), as_fieldop(...))` calls into `as_fieldop(λ(...) → make_tuple(...))(...)`. - Refactor pass such that inlining decision is expressed in a dedicated function `_arg_inline_predicate`. - Inline all let vars with dtype list. - Performance improvement: Stop visiting when reaching a stencil. - Bugfix for inlining of `as_fieldop` args that use the same arg twice, e.g. `as_fieldop(...)(a, b)`. - Bugfix such that only expressions inside the expr of an `itir.SetAt` are considered. --------- Co-authored-by: Hannes Vogt Co-authored-by: Edoardo Paone Co-authored-by: Sara Faghih-Naini --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 +- .../iterator/transforms/collapse_tuple.py | 5 +- .../iterator/transforms/constant_folding.py | 5 + .../transforms/fixed_point_transformation.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 327 +++++++++++++++--- .../inline_center_deref_lift_vars.py | 14 +- .../next/iterator/transforms/inline_lifts.py | 9 +- .../next/iterator/transforms/inline_scalar.py | 2 + .../next/iterator/transforms/merge_let.py | 3 + .../next/iterator/transforms/pass_manager.py | 4 + .../next/iterator/transforms/trace_shifts.py | 12 +- .../iterator/type_system/type_synthesizer.py | 43 ++- .../transforms_tests/test_fuse_as_fieldop.py | 261 +++++++++++++- 13 files changed, 614 insertions(+), 87 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 24842ad3be..42b82ffdd0 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -436,7 +436,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: """ Create an `as_fieldop` call. @@ -445,7 +445,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - return call( + from gt4py.next.iterator.ir_utils import domain_utils + + result = call( call("as_fieldop")( *( ( @@ -458,6 +460,14 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal ) ) + def _populate_domain_annex_wrapper(*args, **kwargs): + node = result(*args, **kwargs) + if domain: + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + return _populate_domain_annex_wrapper + def op_as_fieldop( op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 6db58f3765..462f87b600 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -147,7 +147,7 @@ def all(self) -> CollapseTuple.Transformation: ignore_tuple_size: bool enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] - PRESERVED_ANNEX_ATTRS = ("type",) + PRESERVED_ANNEX_ATTRS = ("type", "domain") @classmethod def apply( @@ -236,6 +236,7 @@ def transform_collapse_make_tuple_tuple_get( # tuple argument differs, just continue with the rest of the tree return None + itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand assert self.ignore_tuple_size or isinstance( first_expr.type, (ts.TupleType, ts.DeferredType) ) @@ -255,7 +256,7 @@ def transform_collapse_tuple_get_make_tuple( and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert type_info.is_integer(node.args[0].type) + assert not node.args[0].type or type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 7215d0787a..0dc324f94c 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,6 +12,11 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index be34af846b..f1176b4bef 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -61,7 +61,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: if result is not None: assert ( result is not node - ) # transformation should have returned None, since nothing changed + ), f"Transformation {transformation.name.lower()} should have returned None, since nothing changed." itir_type_inference.reinfer(result) return result return None diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index cc42896f2b..81633dfb87 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -5,19 +5,29 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations import dataclasses +import enum +import functools +import operator from typing import Optional from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import ( + fixed_point_transformation, inline_center_deref_lift_vars, inline_lambdas, inline_lifts, + merge_let, trace_shifts, ) from gt4py.next.iterator.type_system import inference as type_inference @@ -50,7 +60,6 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -80,7 +89,12 @@ def _inline_as_fieldop_arg( for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) + if inner_arg.id in extracted_args: + assert extracted_args[inner_arg.id] == inner_arg + alias = stencil_params[list(extracted_args.keys()).index(inner_arg.id)] + stencil_body = im.let(inner_param, im.ref(alias.id))(stencil_body) + else: + stencil_params.append(inner_param) extracted_args[inner_arg.id] = inner_arg elif isinstance(inner_arg, itir.Literal): # note: only literals, not all scalar expressions are required as it doesn't make sense @@ -100,12 +114,59 @@ def _inline_as_fieldop_arg( ), extracted_args +def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... ) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s + + def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: - assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert cpm.is_applied_as_fieldop(expr) stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + stencil, restore_scan = _unwrap_scan(stencil) + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop args: list[itir.Expr] = expr.args @@ -118,10 +179,9 @@ def fuse_as_fieldop( if cpm.is_applied_as_fieldop(arg): pass elif cpm.is_call_to(arg, "if_"): + # transform scalar `if` into per-grid-point `if` # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: @@ -134,6 +194,7 @@ def fuse_as_fieldop( new_args = _merge_arguments(new_args, extracted_args) else: # just a safety check if typing information is available + type_inference.reinfer(arg) if arg.type and not isinstance(arg.type, ts.DeferredType): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) @@ -148,26 +209,72 @@ def fuse_as_fieldop( new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) + stencil = im.lambda_(*new_args.keys())(new_stencil_body) + stencil = restore_scan(stencil) # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node + new_stencil = inline_lambdas.InlineLambdas.apply( + stencil, opcount_preserving=True, force_inline_lift_args=False + ) + new_stencil = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_stencil, is_stencil=True, uids=uids ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True + new_stencil = merge_let.MergeLet().visit(new_stencil) + new_stencil = inline_lambdas.InlineLambdas.apply( + new_stencil, opcount_preserving=True, force_inline_lift_args=True ) - new_node = inline_lifts.InlineLifts().visit(new_node) + new_stencil = inline_lifts.InlineLifts().visit(new_stencil) - type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) return new_node -@dataclasses.dataclass -class FuseAsFieldOp(eve.NodeTranslator): +def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool: + if _is_tuple_expr_of_literals(node): + return True + + if ( + is_applied_fieldop := cpm.is_applied_as_fieldop(node) + and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) or cpm.is_call_to(node, "if_"): + # always inline arg if it is an applied fieldop with only a single arg + if is_applied_fieldop and len(node.args) == 1: + return True + # argument is never used, will be removed when inlined + if len(shifts) == 0: + return True + # applied fieldop with list return type must always be inlined as no backend supports this + type_inference.reinfer(node) + assert isinstance(node.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) + if isinstance(dtype, ts.ListType): + return True + # only accessed at the center location + if shifts in [set(), {()}]: + return True + # TODO(tehrengruber): Disabled as the InlineCenterDerefLiftVars does not support this yet + # and it would increase the size of the tree otherwise. + # if len(shifts) == 1 and not any( + # trace_shifts.Sentinel.ALL_NEIGHBORS in access for access in shifts + # ): + # return True # noqa: ERA001 [commented-out-code] + + return False + + +def _make_tuple_element_inline_predicate(node: itir.Expr): + if cpm.is_applied_as_fieldop(node): # field, or tuple of fields + return True + if isinstance(node.type, ts.FieldType) and isinstance(node, itir.SymRef): + return True + return False + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class FuseAsFieldOp( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): """ Merge multiple `as_fieldop` calls into one. @@ -194,6 +301,23 @@ class FuseAsFieldOp(eve.NodeTranslator): as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character + class Transformation(enum.Flag): + #: Let `f_expr` be an expression with list dtype then + #: `let(f, f_expr) -> as_fieldop(...)(f)` -> `as_fieldop(...)(f_expr)` + FUSE_MAKE_TUPLE = enum.auto() + #: `as_fieldop(...)(as_fieldop(...)(a, b), c)` + #: -> as_fieldop(fused_stencil)(a, b, c) + FUSE_AS_FIELDOP = enum.auto() + INLINE_LET_VARS_OPCOUNT_PRESERVING = enum.auto() + + @classmethod + def all(self) -> FuseAsFieldOp.Transformation: + return functools.reduce(operator.or_, self.__members__.values()) + + PRESERVED_ANNEX_ATTRS = ("domain",) + + enabled_transformations = Transformation.all() + uids: eve_utils.UIDGenerator @classmethod @@ -204,48 +328,161 @@ def apply( offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, + within_set_at_expr: Optional[bool] = None, + enabled_transformations: Optional[Transformation] = None, ): + enabled_transformations = enabled_transformations or cls.enabled_transformations + node = type_inference.infer( node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) + if within_set_at_expr is None: + within_set_at_expr = not isinstance(node, itir.Program) + if not uids: uids = eve_utils.UIDGenerator() - return cls(uids=uids).visit(node) + return cls(uids=uids, enabled_transformations=enabled_transformations).visit( + node, within_set_at_expr=within_set_at_expr + ) - def visit_FunCall(self, node: itir.FunCall): - node = self.generic_visit(node) + def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): + if not cpm.is_call_to(node, "make_tuple"): + return None - if cpm.is_call_to(node.fun, "as_fieldop"): - node = _canonicalize_as_fieldop(node) - - if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): - stencil: itir.Lambda = node.fun.args[0] - args: list[itir.Expr] = node.args - shifts = trace_shifts.trace_stencil(stencil) + for arg in node.args: + type_inference.reinfer(arg) + assert not isinstance(arg.type, ts.FieldType) or ( + hasattr(arg.annex, "domain") + and isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + ) - eligible_args = [] - for arg, arg_shifts in zip(args, shifts, strict=True): - assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - # TODO(tehrengruber): make this configurable - eligible_args.append( - _is_tuple_expr_of_literals(arg) - or ( - isinstance(arg, itir.FunCall) - and ( - ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - ) - or cpm.is_call_to(arg, "if_") + eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args] + field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]] + distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) + if len(distinct_domains) != len(field_args): + new_els: list[itir.Expr | None] = [None for _ in node.args] + field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {} + for i, arg in enumerate(node.args): + if eligible_els[i]: + assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + domain = arg.annex.domain.as_expr() + field_args_by_domain.setdefault(domain, []) + field_args_by_domain[domain].append((i, arg)) + else: + new_els[i] = arg # keep as is + + if len(field_args_by_domain) == 1 and all(eligible_els): + # if we only have a single domain covering all args we don't need to create an + # unnecessary let + ((domain, inner_field_args),) = field_args_by_domain.items() + new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_field_args) + ) + new_node = self.visit(new_node, **{**kwargs, "recurse": False}) + else: + let_vars = {} + for domain, inner_field_args in field_args_by_domain.items(): + if len(inner_field_args) > 1: + var = self.uids.sequential_id(prefix="__fasfop") + fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_field_args) ) - and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) - ) + type_inference.reinfer(arg) + # don't recurse into nested args, but only consider newly created `as_fieldop` + # note: this will always inline (as we inline center accessed) + let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(inner_field_args): + new_el = im.tuple_get(outer_tuple_idx, var) + new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + new_els[inner_tuple_idx] = new_el + else: + i, arg = inner_field_args[0] + new_els[i] = arg + assert not any(el is None for el in new_els) + assert let_vars + new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) + new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) + return new_node + return None + + def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): + if cpm.is_applied_as_fieldop(node): + node = _canonicalize_as_fieldop(node) + stencil = node.fun.args[0] # type: ignore[attr-defined] # ensure cpm.is_applied_as_fieldop + assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") + args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil, num_args=len(args)) + + eligible_els = [ + _arg_inline_predicate(arg, arg_shifts) + for arg, arg_shifts in zip(args, shifts, strict=True) + ] + if any(eligible_els): + return self.visit( + fuse_as_fieldop(node, eligible_els, uids=self.uids), + **{**kwargs, "recurse": False}, ) + return None + + def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs): + # when multiple `as_fieldop` calls are fused that use the same argument, this argument + # might become referenced once only. In order to be able to continue fusing such arguments + # try inlining here. + if cpm.is_let(node): + new_node = inline_lambdas.inline_lambda(node, opcount_preserving=True) + if new_node is not node: # nothing has been inlined + return self.visit(new_node, **kwargs) + + return None + + def generic_visit(self, node, **kwargs): + if cpm.is_applied_as_fieldop(node): # don't descend in stencil + return im.as_fieldop(*node.fun.args)(*self.visit(node.args, **kwargs)) + + # TODO(tehrengruber): This is a common pattern that should be absorbed in + # `FixedPointTransformation`. + if kwargs.get("recurse", True): + return super().generic_visit(node, **kwargs) + else: + return node + + def visit(self, node, **kwargs): + if isinstance(node, itir.SetAt): + return itir.SetAt( + expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}), + # rest doesn't need to be visited + domain=node.domain, + target=node.target, + ) + + # don't execute transformations unless inside `SetAt` node + if not kwargs.get("within_set_at_expr"): + return self.generic_visit(node, **kwargs) + + # inline all fields with list dtype. This needs to happen before the children are visited + # such that the `as_fieldop` can be fused. + # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? + # This could duplicate other expressions which we did not intend to duplicate. + # TODO(tehrengruber): This should be moved into a `transform_` method, but + # `FixedPointTransformation` does not support pre-order transformations yet. + if cpm.is_let(node): + for arg in node.args: + type_inference.reinfer(arg) + eligible_els = [ + isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, ts.ListType) + for arg in node.args + ] + if any(eligible_els): + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_els) + return self.visit(node, **kwargs) + + node = super().visit(node, **kwargs) + + if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): + node.annex.domain = node.annex.domain - return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 95c761d7ba..7bd26d0f19 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import ClassVar, Optional +from typing import ClassVar, Optional, TypeVar import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve @@ -23,6 +23,9 @@ def is_center_derefed_only(node: itir.Node) -> bool: return hasattr(node.annex, "recorded_shifts") and node.annex.recorded_shifts in [set(), {()}] +T = TypeVar("T", bound=itir.Program | itir.Lambda) + + @dataclasses.dataclass class InlineCenterDerefLiftVars(eve.NodeTranslator): """ @@ -45,14 +48,19 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): Note: This pass uses and preserves the `recorded_shifts` annex. """ - PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("recorded_shifts",) + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + def apply( + cls, node: T, *, is_stencil=False, uids: Optional[eve_utils.UIDGenerator] = None + ) -> T: if not uids: uids = eve_utils.UIDGenerator() + if is_stencil: + assert isinstance(node, itir.Expr) + trace_shifts.trace_stencil(node, save_to_annex=True) return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 7724aa86f6..166324486a 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -8,7 +8,7 @@ import dataclasses import enum -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits @@ -80,6 +80,7 @@ def _transform_and_extract_lift_args( new_args = [] for i, arg in enumerate(node.args): if isinstance(arg, ir.SymRef): + # TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here? sym = ir.Sym(id=arg.id) assert sym not in extracted_args or extracted_args[sym] == arg extracted_args[sym] = arg @@ -92,6 +93,7 @@ def _transform_and_extract_lift_args( ) assert new_symbol not in extracted_args extracted_args[new_symbol] = arg + # TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here? new_args.append(ir.SymRef(id=new_symbol.id)) itir_node = im.lift(inner_stencil)(*new_args) @@ -112,6 +114,8 @@ class InlineLifts( function nodes. """ + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + class Flag(enum.IntEnum): #: `shift(...)(lift(f)(args...))` -> `lift(f)(shift(...)(args)...)` PROPAGATE_SHIFT = 1 @@ -157,6 +161,9 @@ def visit_FunCall( if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node): shift = node.fun + # This transformation does not preserve the type (the position dims of the iterator + # change). Delete type to avoid errors. + shift.type = None assert len(node.args) == 1 lift_call = node.args[0] new_args = [ diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index d8a6e14d8a..b424074b5c 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -16,6 +16,8 @@ class InlineScalar(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + @classmethod def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): program = itir_inference.infer(program, offset_provider_type=offset_provider_type) diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 0e7d74e594..9c0c25bd49 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import ClassVar import gt4py.eve as eve from gt4py.next.iterator import ir as itir @@ -26,6 +27,8 @@ class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): This can significantly reduce the depth of the tree and its readability. """ + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 0a79848443..4023950dfb 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -70,6 +70,10 @@ def apply_common_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) + # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. + # test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere. + ir = inline_lifts.InlineLifts().visit(ir) + # note: this increases the size of the tree # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 4c44d660f6..0648df8363 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -14,15 +14,14 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import builtins, ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ValidateRecordedShiftsAnnex(eve.NodeVisitor): """Ensure every applied lift and its arguments have the `recorded_shifts` annex populated.""" def visit_FunCall(self, node: ir.FunCall): - if is_applied_lift(node): + if cpm.is_applied_lift(node): assert hasattr(node.annex, "recorded_shifts") if len(node.annex.recorded_shifts) == 0: @@ -329,13 +328,16 @@ def fun(*args): @classmethod def trace_stencil( cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False - ): + ) -> list[set[tuple[ir.OffsetLiteral, ...]]]: # If we get a lambda we can deduce the number of arguments. if isinstance(stencil, ir.Lambda): assert num_args is None or num_args == len(stencil.params) num_args = len(stencil.params) + elif cpm.is_call_to(stencil, "scan"): + assert isinstance(stencil.args[0], ir.Lambda) + num_args = len(stencil.args[0].params) - 1 if not isinstance(num_args, int): - raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + raise ValueError("Stencil must be an 'itir.Lambda', scan, or `num_args` is given.") assert isinstance(num_args, int) args = [im.ref(f"__arg{i}") for i in range(num_args)] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 19ab3ecdda..131b773dd2 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -385,25 +385,30 @@ def apply_shift( assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here return it - new_position_dims = [*it.position_dims] - assert len(offset_literals) % 2 == 0 - for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, common.Dimension - ) - type_ = offset_provider_type[offset_axis.value.value] - if isinstance(type_, common.Dimension): - pass - elif isinstance(type_, common.NeighborConnectivityType): - found = False - for i, dim in enumerate(new_position_dims): - if dim.value == type_.source_dim.value: - assert not found - new_position_dims[i] = type_.codomain - found = True - assert found - else: - raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + new_position_dims: list[common.Dimension] | str + if offset_provider_type: + new_position_dims = [*it.position_dims] + assert len(offset_literals) % 2 == 0 + for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): + assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( + offset_axis.value, common.Dimension + ) + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): + pass + elif isinstance(type_, common.NeighborConnectivityType): + found = False + for i, dim in enumerate(new_position_dims): + if dim.value == type_.source_dim.value: + assert not found + new_position_dims[i] = type_.codomain + found = True + assert found + else: + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + else: + # during re-inference we don't have an offset provider type + new_position_dims = "unknown" return it_ts.IteratorType( position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 168e9490e0..fd884e239f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -5,19 +5,27 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy from typing import Callable, Optional from gt4py import next as gtx from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import fuse_as_fieldop, collapse_tuple from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) +def _with_domain_annex(node: itir.Expr, domain: itir.Expr): + node = copy.deepcopy(node) + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + def test_trivial(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.op_as_fieldop("plus", d)( @@ -46,6 +54,25 @@ def test_trivial_literal(): assert actual == expected +def test_trivial_same_arg_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + # note: inp1 occurs twice here + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp1", field_type)), + im.ref("inp2", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp1")), im.deref("inp2")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_tuple_arg(): d = im.domain("cartesian_domain", {}) testee = im.op_as_fieldop("plus", d)( @@ -99,19 +126,166 @@ def test_no_inline(): im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + )(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee +def test_staged_inlining(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "tmp", im.op_as_fieldop("plus", d)(im.ref("a", field_type), im.ref("b", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 1)), d)("tmp"), + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 2)), d)("tmp"), + ) + ) + expected = im.as_fieldop( + im.lambda_("a", "b")( + im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))( + im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2)) + ) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_make_tuple_fusion_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("b", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref_same_ref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("a", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_nested(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + _with_domain_annex(im.ref("a", field_type), d), + im.make_tuple( + _with_domain_annex(im.ref("b", field_type), d), + _with_domain_annex(im.ref("c", field_type), d), + ), + ) + expected = im.as_fieldop( + im.lambda_("a", "b", "c")( + im.make_tuple(im.deref("a"), im.make_tuple(im.deref("b"), im.deref("c"))) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_different_domains(): + d1 = im.domain("cartesian_domain", {IDim: (0, 1)}) + d2 = im.domain("cartesian_domain", {JDim: (0, 1)}) + field_i_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + field_j_type = ts.FieldType(dims=[JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + testee = im.make_tuple( + im.as_fieldop("deref", d1)(im.ref("a", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("b", field_j_type)), + im.as_fieldop("deref", d1)(im.ref("c", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("d", field_j_type)), + ) + expected = im.let( + ( + "__fasfop_1", + im.as_fieldop(im.lambda_("a", "c")(im.make_tuple(im.deref("a"), im.deref("c"))), d1)( + "a", "c" + ), + ), + ( + "__fasfop_2", + im.as_fieldop(im.lambda_("b", "d")(im.make_tuple(im.deref("b"), im.deref("d"))), d2)( + "b", "d" + ), + ), + )( + im.make_tuple( + im.tuple_get(0, "__fasfop_1"), + im.tuple_get(0, "__fasfop_2"), + im.tuple_get(1, "__fasfop_1"), + im.tuple_get(1, "__fasfop_2"), + ) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_partial_inline(): d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) testee = im.as_fieldop( # first argument read at multiple locations -> not inlined - # second argument only reat at a single location -> inlined + # second argument only read at a single location -> inlined im.lambda_("a", "b")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), @@ -120,19 +294,88 @@ def test_partial_inline(): ), d1, )( - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), ) expected = im.as_fieldop( - im.lambda_("a", "inp1")( + im.lambda_("a", "inp1", "inp2")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), - im.deref("inp1"), + im.plus(im.deref("inp1"), im.deref("inp2")), ) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") + )( + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + "inp1", + "inp2", + ) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected + + +def test_chained_fusion(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "a", im.op_as_fieldop("plus", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))( + im.plus("_icdlv_1", "_icdlv_1") + ) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_inline_as_fieldop_with_list_dtype(): + list_field_type = ts.FieldType( + dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32)) + ) + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop( + im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d + )(im.as_fieldop("deref")(im.ref("inp", list_field_type))) + expected = im.as_fieldop( + im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d + )(im.ref("inp", list_field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) + testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) + expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan_stencil = im.call("scan")( + im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0 + ) + scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) + testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == testee From 5c3393fdc993c2891d601388ab05dd0f93cb1a9e Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 7 Feb 2025 11:22:03 +0100 Subject: [PATCH 134/178] fix[next][dace]: make if_ always execute branch exclusively (#1846) This PR reverts the change previously made in #1824. Lowering `if_` expressions to tasklet is semantically wrong, from a dataflow perspective. It causes segmentation faults in several stencils, that rely on exclusive branch execution. The source problem was that full array shape was passed into the nested SDFG scope, which prevented map fusion in most cases. This PR extends the lowering with the detection of simple iterator dereferencing, without shifts: for this type of data access, only the local element is moved into the nested SDFG. However, when shift is applied on the iterator input (which typically happens in iterator view), the full array shape is still passed. This approach increases the optimization opportunities by enabling more map fusion. At the same time, it keeps the `if_` semantics of exclusive branch execution. --- .../runners/dace/gtir_dataflow.py | 211 +++++++++--------- .../runners/dace/gtir_scan_translator.py | 7 +- 2 files changed, 102 insertions(+), 116 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 04d362b834..e6f33208e3 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -14,6 +14,7 @@ Any, Dict, Final, + Iterable, List, Optional, Protocol, @@ -29,7 +30,7 @@ from gt4py import eve from gt4py.next import common as gtx_common, utils as gtx_utils -from gt4py.next.iterator import builtins, ir as gtir +from gt4py.next.iterator import builtins as gtir_builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace import ( @@ -336,7 +337,6 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - scan_carry_symbol: Optional[gtir.Sym] input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) symbol_map: dict[ str, @@ -533,14 +533,17 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") - # default case: deref a field with one or more dimensions + # handle default case below: deref a field with one or more dimensions + + # when the indices are all dace symbolic expressions, the deref is lowered + # to a memlet, where the index is the memlet subset if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): # when all indices are symbolic expressions, we can perform direct field access through a memlet field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. + # when any of the indices is a runtime value (either a dynamic cartesian + # offset or a connectivity offset), the deref is lowered to a tasklet assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) assert len(field_desc.shape) == len(arg_expr.field_domain) field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] @@ -559,7 +562,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: for dim, index in field_indices ) deref_node = self._add_tasklet( - "runtime_deref", + "deref", {"field"} | set(index_connectors), {"val"}, code=f"val = field[{index_internals}]", @@ -603,6 +606,7 @@ def _visit_if_branch_arg( if_branch_state: dace.SDFGState, param_name: str, arg: IteratorExpr | DataExpr, + deref_on_input_memlet: bool, if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], ) -> IteratorExpr | ValueExpr: """ @@ -613,35 +617,56 @@ def _visit_if_branch_arg( if_branch_state: The state inside the nested SDFG where the if branch is lowered. param_name: The parameter name of the input argument. arg: The input argument expression. + deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. """ + use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): + arg_desc = arg.dc_node.desc(self.sdfg) arg_expr = arg - arg_node = arg.dc_node - arg_desc = arg_node.desc(self.sdfg) - if isinstance(arg, MemletExpr): - assert arg.subset.num_elements() == 1 - arg_desc = dace.data.Scalar(arg_desc.dtype) - else: - assert isinstance(arg_desc, dace.data.Scalar) elif isinstance(arg, IteratorExpr): - arg_node = arg.field - arg_desc = arg_node.desc(self.sdfg) - arg_expr = MemletExpr(arg_node, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc)) + arg_desc = arg.field.desc(self.sdfg) + if deref_on_input_memlet: + # If the iterator is just dereferenced inside the branch state, + # we can access the array outside the nested SDFG and pass the + # local data. This approach makes the data dependencies of nested + # structures more explicit and thus makes it easier for MapFusion + # to correctly infer the data dependencies. + memlet_subset = arg.get_memlet_subset(self.sdfg) + arg_expr = MemletExpr(arg.field, arg.gt_dtype, memlet_subset) + else: + # In order to shift the iterator inside the branch dataflow, + # we have to pass the full array shape. + arg_expr = MemletExpr( + arg.field, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc) + ) + use_full_shape = True else: raise TypeError(f"Unexpected {arg} as input argument.") - if param_name in if_sdfg.arrays: - inner_desc = if_sdfg.data(param_name) - assert not inner_desc.transient - else: + if use_full_shape: inner_desc = arg_desc.clone() inner_desc.transient = False + elif isinstance(arg.gt_dtype, ts.ScalarType): + inner_desc = dace.data.Scalar(arg_desc.dtype) + else: + # for list of values, we retrieve the local size from the corresponding offset + assert arg.gt_dtype.offset_type is not None + offset_provider_type = self.subgraph_builder.get_offset_provider_type( + arg.gt_dtype.offset_type.value + ) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + inner_desc = dace.data.Array(arg_desc.dtype, [offset_provider_type.max_neighbors]) + + if param_name in if_sdfg.arrays: + # the data desciptor was added by the visitor of the other branch expression + assert if_sdfg.data(param_name) == inner_desc + else: if_sdfg.add_datadesc(param_name, inner_desc) if_sdfg_input_memlets[param_name] = arg_expr inner_node = if_branch_state.add_access(param_name) - if isinstance(arg, IteratorExpr): + if isinstance(arg, IteratorExpr) and use_full_shape: return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: return ValueExpr(inner_node, arg.gt_dtype) @@ -652,6 +677,7 @@ def _visit_if_branch( if_branch_state: dace.SDFGState, expr: gtir.Expr, if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + direct_deref_iterators: Iterable[str], ) -> tuple[ list[DataflowInputEdge], tuple[DataflowOutputEdge | tuple[Any, ...], ...], @@ -666,6 +692,7 @@ def _visit_if_branch( if_branch_state: The state inside the nested SDFG where the if branch is lowered. expr: The if branch expression to lower. if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. Returns: A tuple containing: @@ -682,15 +709,29 @@ def _visit_if_branch( ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) psymbol_tree = gtir_sdfg_utils.make_symbol_tree(pname, ptype) + deref_on_input_memlet = pname in direct_deref_iterators inner_arg = gtx_utils.tree_map( - lambda tsym, targ: self._visit_if_branch_arg( - if_sdfg, if_branch_state, tsym.id, targ, if_sdfg_input_memlets + lambda tsym, + targ, + deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg( + if_sdfg, + if_branch_state, + tsym.id, + targ, + deref_on_input_memlet, + if_sdfg_input_memlets, ) )(psymbol_tree, arg) else: psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + deref_on_input_memlet = pname in direct_deref_iterators inner_arg = self._visit_if_branch_arg( - if_sdfg, if_branch_state, pname, arg, if_sdfg_input_memlets + if_sdfg, + if_branch_state, + pname, + arg, + deref_on_input_memlet, + if_sdfg_input_memlets, ) lambda_args.append(inner_arg) lambda_params.append(psymbol) @@ -742,11 +783,6 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[A Lowers an if-expression with exclusive branch execution into a nested SDFG, in which each branch is lowered into a dataflow in a separate state and the if-condition is represented as the inter-state edge condition. - - Exclusive branch execution for local if expressions is meant to be used - in iterator view. Iterator view is required ONLY inside scan field operators. - For regular field operators, the fieldview behavior of if-expressions - corresponds to a local select, therefore it should be lowered to a tasklet. """ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: @@ -805,9 +841,41 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg.add_scalar("__cond", dace.dtypes.bool) input_memlets["__cond"] = condition_value + # Collect all field iterators that are shifted inside any of the then/else + # branch expressions. Iterator shift expressions require the field argument + # as iterator, therefore the corresponding array has to be passed with full + # shape into the nested SDFG where the if_ expression is lowered. When the + # branch expression simply does `deref` on the iterator, without any shifting, + # it corresponds to a direct element access. Such `deref` expressions can + # be lowered outside the nested SDFG, so that just the local value (a scalar + # or a list of values) is passed as input to the nested SDFG. + shifted_iterator_symbols = set() + for branch_expr in node.args[1:3]: + for shift_node in eve.walk_values(branch_expr).filter( + lambda x: cpm.is_applied_shift(x) + ): + shifted_iterator_symbols |= ( + eve.walk_values(shift_node) + .if_isinstance(gtir.SymRef) + .map(lambda x: str(x.id)) + .filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr)) + .to_set() + ) + iterator_symbols = { + sym_name + for sym_name, sym_type in self.symbol_map.items() + if isinstance(sym_type, IteratorExpr) + } + direct_deref_iterators = ( + set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols)) + - shifted_iterator_symbols + ) + for nstate, arg in zip([tstate, fstate], node.args[1:3]): # visit each if-branch in the corresponding state of the nested SDFG - in_edges, output_tree = self._visit_if_branch(nsdfg, nstate, arg, input_memlets) + in_edges, output_tree = self._visit_if_branch( + nsdfg, nstate, arg, input_memlets, direct_deref_iterators + ) for edge in in_edges: edge.connect(map_entry=None) @@ -1511,7 +1579,7 @@ def _make_unstructured_shift( def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type IndexDType: Final = gtx_dace_utils.as_dace_type( - ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) + ts.ScalarType(kind=getattr(ts.ScalarKind, gtir_builtins.INTEGER_INDEX_BUILTIN.upper())) ) assert isinstance(node.fun, gtir.FunCall) @@ -1637,87 +1705,13 @@ def _visit_tuple_get( tuple_fields = self.visit(node.args[1]) return tuple_fields[index] - def requires_exclusive_if(self, node: gtir.FunCall) -> bool: - """ - The meaning of `if_` builtin function is unclear in GTIR. - In some context, it corresponds to a ternary operator where, depending on - the condition result, only one branch or the other should be executed, - because one of them is invalid. The typical case is the use of `if_` to - decide whether it is possible or not to access a shifted iterator, for - example when the condition expression calls `can_deref`. - The ternary operator is also used in iterator view, where the field arguments - are not necessarily both defined on the entire output domain (this behavior - should not appear in field view, because there the user code should use - `concat_where` instead of `where` for such cases). It is difficult to catch - such behavior, because it would require to know the exact domain of all - fields, which is not known at compile time. However, the iterator view - behavior should only appear inside scan field operators. - A different usage of `if_` expressions is selecting one argument value or - the other, where both arguments are defined on the output domain, therefore - always valid. - In order to simplify the SDFG and facilitate the optimization stage, we - try to avoid the ternary operator form when not needed. The reason is that - exclusive branch execution is represented in the SDFG as a conditional - state transition, which prevents fusion. - """ - assert cpm.is_call_to(node, "if_") - assert len(node.args) == 3 - - condition_vars = ( - eve.walk_values(node.args[0]) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: x in self.symbol_map) - .to_set() - ) - - # first, check if any argument contains shift expressions that depend on the condition variables - for arg in node.args[1:3]: - shift_nodes = ( - eve.walk_values(arg).filter(lambda node: cpm.is_applied_shift(node)).to_set() - ) - for shift_node in shift_nodes: - shift_vars = ( - eve.walk_values(shift_node) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: x in self.symbol_map) - .to_set() - ) - # require exclusive branch execution if any shift expression one of - # the if branches accesses a variable used in the condition expression - depend_vars = condition_vars.intersection(shift_vars) - if len(depend_vars) != 0: - return True - - # secondly, check whether the `if_` branches access different sets of fields - # and this happens inside a scan field operator - if self.scan_carry_symbol is not None: - # the `if_` node is inside a scan stencil expression - scan_carry_var = str(self.scan_carry_symbol.id) - if scan_carry_var in condition_vars: - br1_vars, br2_vars = ( - eve.walk_values(arg) - .if_isinstance(gtir.SymRef) - .map(lambda node: str(node.id)) - .filter(lambda x: isinstance(self.symbol_map.get(x, None), MemletExpr)) - .to_set() - for arg in node.args[1:3] - ) - if br1_vars != br2_vars: - # the two branches of the `if_` expression access different sets of fields, - # depending on the scan carry value - return True - - return False - def visit_FunCall( self, node: gtir.FunCall ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) - elif cpm.is_call_to(node, "if_") and self.requires_exclusive_if(node): + elif cpm.is_call_to(node, "if_"): return self._visit_if(node) elif cpm.is_call_to(node, "neighbors"): @@ -1854,7 +1848,6 @@ def translate_lambda_to_dataflow( | ValueExpr | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] ], - scan_carry_symbol: Optional[gtir.Sym] = None, ) -> tuple[ list[DataflowInputEdge], tuple[DataflowOutputEdge | tuple[Any, ...], ...], @@ -1873,15 +1866,13 @@ def translate_lambda_to_dataflow( sdfg_builder: Helper class to build the dataflow inside the given SDFG. node: Lambda node to visit. args: Arguments passed to lambda node. - scan_carry_symbol: When set, the lowering of `if_` expression will consider - using the ternary operator form with exclusive branch execution. Returns: A tuple of two elements: - List of connections for data inputs to the dataflow. - Tree representation of output data connections. """ - taskgen = LambdaToDataflow(sdfg, state, sdfg_builder, scan_carry_symbol) + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) lambda_output = taskgen.visit_let(node, args) if isinstance(lambda_output, DataflowOutputEdge): diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py index 743b4d33e4..da10d4bddd 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_scan_translator.py @@ -414,12 +414,7 @@ def init_scan_carry(sym: gtir.Sym) -> None: # stil inside the 'compute' state, generate the dataflow representing the stencil # to be applied on the horizontal domain lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( - nsdfg, - compute_state, - lambda_translator, - lambda_node, - stencil_args, - scan_carry_symbol=scan_carry_symbol, + nsdfg, compute_state, lambda_translator, lambda_node, stencil_args ) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region From 34b574a3d3e448fb2ccd3a131c106a1ac26b16c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 7 Feb 2025 11:51:31 +0100 Subject: [PATCH 135/178] fix[dace]: Updating MapFusion (#1850) The [MapFusion PR](https://github.com/spcl/dace/pull/1629) in DaCe is still under review. However, the MapFusion in that PR has evolved, i.e. some bugs were fixed and now GT4Py is also these bugs. This PR essentially back ports some of the fixes to GT4Py. Note that this is a temporary solution and as soon as the MapFusion PR has been merged (and parallel map fusion has been introduced) the GT4Py version will go away. --- .../dace/transformations/map_fusion_serial.py | 183 +++++++++++------- 1 file changed, 118 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 0ef33cae97..27d962d0bd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import dace -from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace import data, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes from . import map_fusion_helper as mfh @@ -752,8 +752,10 @@ def handle_intermediate_set( Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. """ - - map_params = map_exit_1.map.params.copy() + first_map_exit = map_exit_1 + second_map_entry = map_entry_2 + second_map_exit = map_exit_2 + map_params = first_map_exit.map.params.copy() # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. @@ -763,36 +765,22 @@ def handle_intermediate_set( inter_node: nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) ) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( + self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + ) # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its @@ -808,7 +796,6 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dtypes.StorageType.Register, find_new_name=True, ) @@ -822,32 +809,30 @@ def handle_intermediate_set( shape=new_inter_shape, dtype=inter_desc.dtype, find_new_name=True, - storage=dtypes.StorageType.Register, ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) # Get the subset that defined into which part of the old intermediate # the old output edge wrote to. We need that to adjust the producer # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None producer_offset = self.compute_offset_subset( original_subset=pre_exit_edge.data.dst_subset, intermediate_desc=inter_desc, map_params=map_params, + producer_offset=None, ) - # Memlets have a lot of additional informations, such as dynamic. - # To ensure that we get all of them, we will now copy them and modify - # the one that was originally there. We also hope that propagate will - # set the rest for us correctly. + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) new_pre_exit_memlet.data = new_inter_name - new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # NOTE: We will delete the previous edge later. new_pre_exit_edge = state.add_edge( pre_exit_edge.src, pre_exit_edge.src_conn, @@ -856,6 +841,11 @@ def handle_intermediate_set( new_pre_exit_memlet, ) + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + # We now handle the MemletTree defined by this edge. # The newly created edge, only handled the last collection step. for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( @@ -863,11 +853,15 @@ def handle_intermediate_set( ): producer_edge = producer_tree.edge - # Associate the (already existing) Memlet with the new data. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - producer_edge.data.data = new_inter_name + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. if is_scalar: producer_edge.data.dst_subset = "0" elif producer_edge.data.dst_subset is not None: @@ -885,7 +879,7 @@ def handle_intermediate_set( # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: Set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: + if inter_node_out_edge.dst == second_map_entry: assert inter_node_out_edge.dst_conn.startswith("IN_") conn_names.add(inter_node_out_edge.dst_conn) else: @@ -900,9 +894,7 @@ def handle_intermediate_set( for in_conn_name in conn_names: out_conn_name = "OUT_" + in_conn_name[3:] - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. assert inner_edge.data.src_subset is not None @@ -913,11 +905,17 @@ def handle_intermediate_set( producer_offset=producer_offset, ) - # Now we create a new connection that instead reads from the new - # intermediate, instead of the old one. For this we use the - # old Memlet as template. However it is not fully initialized. + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.data = new_inter_name + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name # Now we replace the edge from the SDFG. state.remove_edge(inner_edge) @@ -934,6 +932,7 @@ def handle_intermediate_set( if is_scalar: new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. new_inner_memlet.src_subset.offset(consumer_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) @@ -941,23 +940,30 @@ def handle_intermediate_set( for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( include_self=False ): - assert consumer_tree.edge.data.data == inter_name - consumer_edge = consumer_tree.edge - consumer_edge.data.data = new_inter_name + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. consumer_edge.data.src_subset.offset(consumer_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. We now delete # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): assert edge.src == inter_node state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. @@ -967,41 +973,88 @@ def handle_intermediate_set( state.remove_node(inter_node) state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) del sdfg.arrays[inter_name] else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + # This is the shared mode, so we have to recreate the intermediate # node, but this time it is at the exit of the second map. state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) # This is the Memlet that goes from the map internal intermediate # temporary node to the Map output. This will essentially restore # or preserve the output for the intermediate node. It is important # that we use the data that `preExitEdge` was used. final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert pre_exit_edge.data.data == inter_name final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) - new_pre_exit_conn = map_exit_2.next_connector() + new_pre_exit_conn = second_map_exit.next_connector() state.add_edge( new_inter_node, None, - map_exit_2, + second_map_exit, "IN_" + new_pre_exit_conn, final_pre_exit_memlet, ) state.add_edge( - map_exit_2, + second_map_exit, "OUT_" + new_pre_exit_conn, inter_node, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) state.remove_edge(out_edge) + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape, strict=True) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) From 4b566d7f75369ddd2fd56d67def59e58973bd777 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 7 Feb 2025 20:56:15 +0100 Subject: [PATCH 136/178] ci[cartesian]: Thread safe parallel stencil tests (#1849) ## Description To avoid repeating boiler plate code in testing, `StencilTestSuite` provides a convenient interace to test gtscript stencils. Within that `StencilTestSuite` base class, generating the stencil is separated from running & validating the stencil code. Each deriving test class will end up with two tests: one for stencil generation and a second one to test the implementation by running the generated code with defined inputs and expected outputs. The base class was written such that the implementation test would re-use the generated stencil code from the first test. This introduces an implicit test order dependency. To save time and avoid unnecessary test failure outputs, failing to generate the stencil code would automatically skip the implementation/validation test. Running tests in parallel (with `xdist`) breaks the expected test execution order (in the default configuration). This leads to automatically skiped validation tests in case the stencil code wasn't generated yet. On the CI, we only run with 2 threads so only a couple tests were skipped usually. Locally, I was running with 16 threads and got ~30 skipped validation tests. This PR proposes to address the issue by setting an `xdist_group` mark on the generation/implementation tests that belong togehter. In combination with `--dist loadgroup`, this will keep the expected order where necessary. Only tests with `xdist_group` markers are affected by `--dist loadgroup`. Tests without that marker will be distributed normally as if in `--dist load` mode (the default so far). By grouping with `cls_name` and backend, we keep maximal parallelization, grouping only the two tests that are depending on each other. Further reading: see [`--dist` section](https://pytest-xdist.readthedocs.io/en/stable/distribution.html) in `pytest-xdist` documentation. ## Requirements - [x] All fixes and/or new features come with corresponding tests. Existing tests are still green. No more skipped tests \o/ Works as expected locally - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- noxfile.py | 2 +- src/gt4py/cartesian/testing/suites.py | 56 +++++++++++++++++---------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/noxfile.py b/noxfile.py index e119669e92..81b1354157 100644 --- a/noxfile.py +++ b/noxfile.py @@ -91,7 +91,7 @@ def test_cartesian( markers = " and ".join(codegen_settings["markers"] + device_settings["markers"]) session.run( - *f"pytest --cache-clear -sv -n {num_processes}".split(), + *f"pytest --cache-clear -sv -n {num_processes} --dist loadgroup".split(), *("-m", f"{markers}"), str(pathlib.Path("tests") / "cartesian_tests"), *session.posargs, diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index f680a1dbef..423f834f51 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -167,7 +167,7 @@ def get_globals_combinations(dtypes): generation_strategy=composite_strategy_factory( d, generation_strategy_factories ), - implementations=[], + implementation=None, test_id=len(cls_dict["tests"]), definition=annotate_function( function=cls_dict["definition"], @@ -199,14 +199,19 @@ def hyp_wrapper(test_hyp, hypothesis_data): for test in cls_dict["tests"]: if test["suite"] == cls_name: - marks = test["marks"] - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": - marks.append(pytest.mark.requires_gpu) name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) name += "".join( "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() ) + + marks = test["marks"].copy() + if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + marks.append(pytest.mark.requires_gpu) + # Run generation and implementation tests in the same group to ensure + # (thread-) safe parallelization of stencil tests. + marks.append(pytest.mark.xdist_group(name=f"{cls_name}_{name}")) + param = pytest.param(test, marks=marks, id=name) pytest_params.append(param) @@ -228,14 +233,19 @@ def hyp_wrapper(test_hyp, hypothesis_data): runtime_pytest_params = [] for test in cls_dict["tests"]: if test["suite"] == cls_name: - marks = test["marks"] - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": - marks.append(pytest.mark.requires_gpu) name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) name += "".join( "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() ) + + marks = test["marks"].copy() + if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + marks.append(pytest.mark.requires_gpu) + # Run generation and implementation tests in the same group to ensure + # (thread-) safe parallelization of stencil tests. + marks.append(pytest.mark.xdist_group(name=f"{cls_name}_{name}")) + runtime_pytest_params.append( pytest.param( test, @@ -434,8 +444,11 @@ class StencilTestSuite(metaclass=SuiteMeta): def _test_generation(cls, test, externals_dict): """Test source code generation for all *backends* and *stencil suites*. - The generated implementations are cached in a :class:`utils.ImplementationsDB` - instance, to avoid duplication of (potentially expensive) compilations. + The generated implementation is cached in the test context, to avoid duplication + of (potentially expensive) compilation. + Note: This caching introduces a dependency between tests, which is captured by an + `xdist_group` marker in combination with `--dist loadgroup` to ensure safe parallel + test execution. """ backend_slug = gt_utils.slugify(test["backend"], valid_symbols="") implementation = gtscript.stencil( @@ -461,7 +474,8 @@ def _test_generation(cls, test, externals_dict): or ax == "K" or field_info.boundary[i] >= cls.global_boundaries[name][i] ) - test["implementations"].append(implementation) + assert test["implementation"] is None + test["implementation"] = implementation @classmethod def _run_test_implementation(cls, parameters_dict, implementation): # too complex @@ -585,16 +599,16 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl def _test_implementation(cls, test, parameters_dict): """Test computed values for implementations generated for all *backends* and *stencil suites*. - The generated implementations are reused from previous tests by means of a - :class:`utils.ImplementationsDB` instance shared at module scope. + The generated implementation was cached in the test context, to avoid duplication + of (potentially expensive) compilation. + Note: This caching introduces a dependency between tests, which is captured by an + `xdist_group` marker in combination with `--dist loadgroup` to ensure safe parallel + test execution. """ - implementation_list = test["implementations"] - if not implementation_list: - pytest.skip( - "Cannot perform validation tests, since there are no valid implementations." - ) - for implementation in implementation_list: - if not isinstance(implementation, StencilObject): - raise RuntimeError("Wrong function got from implementations_db cache!") + implementation = test["implementation"] + assert ( + implementation is not None + ), "Stencil implementation not found. This usually means code generation failed." + assert isinstance(implementation, StencilObject) - cls._run_test_implementation(parameters_dict, implementation) + cls._run_test_implementation(parameters_dict, implementation) From 05b3b8737f4e6b177822ddf89564d7cf488d73b6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 10 Feb 2025 15:15:36 +0100 Subject: [PATCH 137/178] tests[cartesian]: Increase horizontal region test coverage (#1851) ## Description A a test case where horizontal regions are used to write to a subset of the field. Note that this is different from `TestHorizontalRegionsCorners` because we have specialized optimizations in place for corners of the cube. Related issue: https://github.com/GridTools/gt4py/issues/720 ## Requirements - [x] All fixes and/or new features come with corresponding tests. New tests added - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .../multi_feature_tests/test_suites.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 0312aea7c3..b01a12fc7f 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -876,7 +876,13 @@ def validation(field_in, field_out, *, domain, origin): field_out[:, -1, :] = field_in[:, -1, :] - 1.0 -class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): +class TestHorizontalRegionsPartialWrites(gt_testing.StencilTestSuite): + """Use horizontal regions to only write to certain parts of the field. + + This test is different from the corner case below because the corner + case follows a different code path (we have specific optimizations for + them).""" + dtypes = {"field_in": np.float32, "field_out": np.float32} domain_range = [(4, 4), (4, 4), (2, 2)] backends = ALL_BACKENDS @@ -885,8 +891,40 @@ class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] ), "field_out": gt_testing.field( + in_range=(42, 42), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] + ), + } + + def definition(field_in, field_out): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0], :], region[I[-1], :]): + field_out = ( # noqa: F841 [unused-variable] + field_in + 1.0 + ) + with horizontal(region[:, J[0]], region[:, J[-1]]): + field_out = ( # noqa: F841 [unused-variable] + field_in - 1.0 + ) + + def validation(field_in, field_out, *, domain, origin): + field_out[:, :, :] = 42 + field_out[0, :, :] = field_in[0, :, :] + 1.0 + field_out[-1, :, :] = field_in[-1, :, :] + 1.0 + field_out[:, 0, :] = field_in[:, 0, :] - 1.0 + field_out[:, -1, :] = field_in[:, -1, :] - 1.0 + + +class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): + dtypes = {"field_in": np.float32, "field_out": np.float32} + domain_range = [(4, 4), (4, 4), (2, 2)] + backends = ALL_BACKENDS + symbols = { + "field_in": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] ), + "field_out": gt_testing.field( + in_range=(42, 42), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] + ), } def definition(field_in, field_out): @@ -901,6 +939,7 @@ def definition(field_in, field_out): ) def validation(field_in, field_out, *, domain, origin): + field_out[:, :, :] = 42 field_out[0:2, 0:2, :] = field_in[0:2, 0:2, :] + 1.0 field_out[-3:-1, -3:-1, :] = field_in[-3:-1, -3:-1, :] + 1.0 field_out[0:2, -3:-1, :] = field_in[0:2, -3:-1, :] - 1.0 From 3be506415ac97121d3756ffe0549afd194dbcd38 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 11 Feb 2025 09:12:51 +0100 Subject: [PATCH 138/178] test[next]: enable gpu test for 1d scan on dace backend (#1854) The dace backend did not support the case of scan on a 1D vertical array, when lowered to GPU code. After upgrading the dace package to latest main and adopting the `LoopRegion` construct for the lowering of scan, this case is no longer an issue. This PR just re-enables the scan test case. --- pyproject.toml | 1 - tests/next_tests/definitions.py | 13 ++----------- .../feature_tests/ffront_tests/test_execution.py | 1 - 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a9071e9d4..b512c6c93e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -310,7 +310,6 @@ markers = [ 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scalar_in_domain_and_fo', 'uses_scan: tests that uses scan', - 'uses_scan_1d_field: that that uses scan on 1D vertical field', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_scan_in_stencil: tests that require backend support for scan in stencil', 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index a96d967430..b412c0c273 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -111,7 +111,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" -USES_SCAN_1D_FIELD = "uses_scan_1d_field" USES_SPARSE_FIELDS = "uses_sparse_fields" USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" @@ -190,17 +189,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [ - # dace issue https://github.com/spcl/dace/issues/1773 - (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), - ], + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST - + [ - # dace issue https://github.com/spcl/dace/issues/1773 - (USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE), - ], + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 95bde32107..f02fdf4cc4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -821,7 +821,6 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan -@pytest.mark.uses_scan_1d_field def test_ternary_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: From c96e19e13e507448b2cf12c07a5c8a67987c8a35 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 11 Feb 2025 12:52:07 +0100 Subject: [PATCH 139/178] build: upgrade gt4py to numpy 2.x (#1852) Upgrade dace to latest main version to include the support for numpy 2. All gt4py cartesian and next tests are compatible with both numpy 1.x and 2.x, except cartesian-dace that is limited by dace v1 to numpy < 2. --- src/gt4py/next/ffront/func_to_foast.py | 5 +- src/gt4py/next/iterator/embedded.py | 3 +- src/gt4py/storage/allocators.py | 5 +- src/gt4py/storage/cartesian/utils.py | 8 +- .../feature_tests/test_field_layouts.py | 8 +- .../stencil_definitions.py | 2 +- .../test_code_generation.py | 64 +-- .../multi_feature_tests/test_suites.py | 38 +- .../backend_tests/test_module_generator.py | 2 +- .../frontend_tests/test_gtscript_frontend.py | 108 ++--- .../ffront_tests/test_execution.py | 11 +- .../iterator_tests/test_vertical_advection.py | 10 +- uv.lock | 400 +++++++++++------- 13 files changed, 395 insertions(+), 269 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ebe12d3a8b..ef20b99d91 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -60,8 +60,9 @@ def func_to_foast(inp: DSL_FOP) -> FOP: >>> print(foast_definition.foast_node.id) dsl_operator - >>> print(foast_definition.closure_vars) - {'const': 2.0} + >>> foast_closure_vars = {k: str(v) for k, v in foast_definition.closure_vars.items()} + >>> print(foast_closure_vars) + {'const': '2.0'} """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 16b1fa9d03..da0516d26b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -16,7 +16,6 @@ import itertools import math import operator -import sys import warnings import numpy as np @@ -677,7 +676,7 @@ def __float__(self): return np.nan def __int__(self): - return sys.maxsize + return np.iinfo(np.int32).max def __repr__(self): return "_UNDEFINED" diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 298b9c2e5a..e2311e3e60 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -211,9 +211,10 @@ def allocate( # Compute the padding required in the contiguous dimension to get aligned blocks dims_layout = [layout_map.index(i) for i in range(len(shape))] - padded_shape_lst = list(shape) + # Convert shape size to same data type (note that `np.int16` can overflow) + padded_shape_lst = [np.int32(x) for x in shape] if ndim > 0: - padded_shape_lst[dims_layout[-1]] = ( + padded_shape_lst[dims_layout[-1]] = ( # type: ignore[call-overload] math.ceil(shape[dims_layout[-1]] / items_per_aligned_block) * items_per_aligned_block ) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 50500e536b..bd89c85052 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -12,10 +12,11 @@ import functools import math import numbers -from typing import Any, Final, Literal, Optional, Sequence, Tuple, Union, cast +from typing import Final, Literal, Optional, Sequence, Tuple, Union, cast import numpy as np import numpy.typing as npt +from numpy.typing import DTypeLike from gt4py._core import definitions as core_defs from gt4py.cartesian import config as gt_config @@ -23,11 +24,6 @@ from gt4py.storage import allocators -if np.lib.NumpyVersion(np.__version__) >= "1.20.0": - from numpy.typing import DTypeLike -else: - DTypeLike = Any # type: ignore[misc] # assign multiple types in both branches - try: import cupy as cp except ImportError: diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py index c1b4e58f97..c3bf40e456 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py @@ -26,8 +26,8 @@ def test_numpy_allocators(backend, order): xp = get_array_library(backend) shape = (20, 10, 5) - inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float_) - outp = xp.zeros(shape=shape, order=order, dtype=xp.float_) + inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float64) + outp = xp.zeros(shape=shape, order=order, dtype=xp.float64) stencil = gtscript.stencil(definition=copy_stencil, backend=backend) stencil(field_a=inp, field_b=outp) @@ -43,8 +43,8 @@ def test_bad_layout_warns(backend): shape = (10, 10, 10) - inp = xp.array(xp.random.randn(*shape), dtype=xp.float_) - outp = gt_storage.zeros(backend=backend, shape=shape, dtype=xp.float_, aligned_index=(0, 0, 0)) + inp = xp.array(xp.random.randn(*shape), dtype=xp.float64) + outp = gt_storage.zeros(backend=backend, shape=shape, dtype=xp.float64, aligned_index=(0, 0, 0)) # set up non-optimal storage layout: if backend_cls.storage_info["is_optimal_layout"](inp, "IJK"): diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index e1d9a0061a..8112866092 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -61,7 +61,7 @@ def _register_decorator(actual_func): return _register_decorator(func) if func else _register_decorator -Field3D = gtscript.Field[np.float_] +Field3D = gtscript.Field[np.float64] Field3DBool = gtscript.Field[np.bool_] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 000dc34c7f..8ace0de740 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -59,7 +59,7 @@ def test_generation(name, backend): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_lazy_stencil(backend): @gtscript.lazy_stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.float_]): + def definition(field_a: gtscript.Field[np.float64], field_b: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): field_a = field_b @@ -67,7 +67,7 @@ def definition(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.fl @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_temporary_field_declared_in_if(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_]): + def definition(field_a: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): if field_a < 0: field_b = -field_a @@ -79,7 +79,7 @@ def definition(field_a: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_stage_without_effect(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_]): + def definition(field_a: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): field_c = 0.0 @@ -87,11 +87,11 @@ def definition(field_a: gtscript.Field[np.float_]): def test_ignore_np_errstate(): def setup_and_run(backend, **kwargs): field_a = gt_storage.zeros( - dtype=np.float_, backend=backend, shape=(3, 3, 1), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(3, 3, 1), aligned_index=(0, 0, 0) ) @gtscript.stencil(backend=backend, **kwargs) - def divide_by_zero(field_a: gtscript.Field[np.float_]): + def divide_by_zero(field_a: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): field_a = 1.0 / field_a @@ -106,11 +106,11 @@ def divide_by_zero(field_a: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", CPU_BACKENDS) def test_stencil_without_effect(backend): - def definition1(field_in: gtscript.Field[np.float_]): + def definition1(field_in: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): tmp = 0.0 - def definition2(f_in: gtscript.Field[np.float_]): + def definition2(f_in: gtscript.Field[np.float64]): from __externals__ import flag with computation(PARALLEL), interval(...): @@ -121,7 +121,7 @@ def definition2(f_in: gtscript.Field[np.float_]): stencil2 = gtscript.stencil(backend, definition2, externals={"flag": False}) field_in = gt_storage.ones( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) # test with explicit domain specified @@ -135,14 +135,14 @@ def definition2(f_in: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", CPU_BACKENDS) def test_stage_merger_induced_interval_block_reordering(backend): field_in = gt_storage.ones( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) field_out = gt_storage.zeros( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) @gtscript.stencil(backend=backend) - def stencil(field_in: gtscript.Field[np.float_], field_out: gtscript.Field[np.float_]): + def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): with computation(BACKWARD): with interval(-2, -1): # block 1 field_out = field_in @@ -164,9 +164,9 @@ def stencil(field_in: gtscript.Field[np.float_], field_out: gtscript.Field[np.fl def test_lower_dimensional_inputs(backend): @gtscript.stencil(backend=backend) def stencil( - field_3d: gtscript.Field[gtscript.IJK, np.float_], - field_2d: gtscript.Field[gtscript.IJ, np.float_], - field_1d: gtscript.Field[gtscript.K, np.float_], + field_3d: gtscript.Field[gtscript.IJK, np.float64], + field_2d: gtscript.Field[gtscript.IJ, np.float64], + field_1d: gtscript.Field[gtscript.K, np.float64], ): with computation(PARALLEL): with interval(0, -1): @@ -219,9 +219,9 @@ def stencil( def test_lower_dimensional_masked(backend): @gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float_], - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + cond: gtscript.Field[gtscript.IJK, np.float64], + inp: gtscript.Field[gtscript.IJ, np.float64], + outp: gtscript.Field[gtscript.IJK, np.float64], ): with computation(PARALLEL), interval(...): if cond > 0.0: @@ -250,9 +250,9 @@ def copy_2to3( def test_lower_dimensional_masked_2dcond(backend): @gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float_], - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + cond: gtscript.Field[gtscript.IJK, np.float64], + inp: gtscript.Field[gtscript.IJ, np.float64], + outp: gtscript.Field[gtscript.IJK, np.float64], ): with computation(FORWARD), interval(...): if cond > 0.0: @@ -281,8 +281,8 @@ def copy_2to3( def test_lower_dimensional_inputs_2d_to_3d_forward(backend): @gtscript.stencil(backend=backend) def copy_2to3( - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + inp: gtscript.Field[gtscript.IJ, np.float64], + outp: gtscript.Field[gtscript.IJK, np.float64], ): with computation(FORWARD), interval(...): outp[0, 0, 0] = inp @@ -368,8 +368,8 @@ def stencil( def test_variable_offsets(backend): @gtscript.stencil(backend=backend) def stencil_ij( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], index_field: gtscript.Field[gtscript.IJ, int], ): with computation(FORWARD), interval(...): @@ -378,8 +378,8 @@ def stencil_ij( @gtscript.stencil(backend=backend) def stencil_ijk( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], index_field: gtscript.Field[int], ): with computation(PARALLEL), interval(...): @@ -390,10 +390,10 @@ def stencil_ijk( def test_variable_offsets_and_while_loop(backend): @gtscript.stencil(backend=backend) def stencil( - pe1: gtscript.Field[np.float_], - pe2: gtscript.Field[np.float_], - qin: gtscript.Field[np.float_], - qout: gtscript.Field[np.float_], + pe1: gtscript.Field[np.float64], + pe2: gtscript.Field[np.float64], + qin: gtscript.Field[np.float64], + qout: gtscript.Field[np.float64], lev: gtscript.Field[gtscript.IJ, np.int_], ): with computation(FORWARD), interval(0, -1): @@ -410,7 +410,7 @@ def stencil( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_nested_while_loop(backend): @gtscript.stencil(backend=backend) - def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_]): + def stencil(field_a: gtscript.Field[np.float64], field_b: gtscript.Field[np.int_]): with computation(PARALLEL), interval(...): while field_a < 1: add = 0 @@ -422,7 +422,7 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): @gtscript.stencil(backend) - def stencil(outp: gtscript.Field[np.float_]): + def stencil(outp: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): cond = True if cond[0, -1, 0] or cond[0, 0, 0]: diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index b01a12fc7f..10d8999565 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -47,7 +47,7 @@ def validation(field_a, domain=None, origin=None): class TestCopy(gt_testing.StencilTestSuite): """Copy stencil.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 25), (1, 25), (1, 25)] backends = ALL_BACKENDS symbols = dict( @@ -66,7 +66,7 @@ def validation(field_a, field_b, domain=None, origin=None): class TestAugAssign(gt_testing.StencilTestSuite): """Increment by one stencil.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 25), (1, 25), (1, 25)] backends = ALL_BACKENDS symbols = dict( @@ -90,7 +90,7 @@ def validation(field_a, field_b, domain=None, origin=None): class TestGlobalScale(gt_testing.StencilTestSuite): """Scale stencil using a global global_name.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -112,7 +112,7 @@ def validation(field_a, domain, origin, **kwargs): class TestParametricScale(gt_testing.StencilTestSuite): """Scale stencil using a parameter.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -136,7 +136,7 @@ class TestParametricMix(gt_testing.StencilTestSuite): ("USE_ALPHA",): np.int_, ("field_a", "field_b", "field_c"): np.float64, ("field_out",): np.float32, - ("weight", "alpha_factor"): np.float_, + ("weight", "alpha_factor"): np.float64, } domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS @@ -177,7 +177,7 @@ def validation( class TestHeatEquation_FTCS_3D(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -206,7 +206,7 @@ def validation(u, v, u_new, v_new, *, ru, rv, domain, origin, **kwargs): class TestHorizontalDiffusion(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -270,7 +270,7 @@ def fwd_diff_op_y(field): class TestHorizontalDiffusionSubroutines(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -305,7 +305,7 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): class TestHorizontalDiffusionSubroutines2(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -346,7 +346,7 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): class TestRuntimeIfFlat(gt_testing.StencilTestSuite): """Tests runtime ifs.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict(outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -365,7 +365,7 @@ def validation(outfield, *, domain, origin, **kwargs): class TestRuntimeIfNested(gt_testing.StencilTestSuite): """Tests nested runtime ifs.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict(outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -391,7 +391,7 @@ def add_one(field_in): class Test3FoldNestedIf(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(3, 3), (3, 3), (3, 3)] backends = ALL_BACKENDS symbols = dict(field_a=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -411,7 +411,7 @@ def validation(field_a, domain, origin): class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(3, 3), (3, 3), (3, 3)] backends = ALL_BACKENDS symbols = dict( @@ -447,7 +447,7 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs): class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite): """Test conditional while statements.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -475,7 +475,7 @@ def validation(infield, outfield, *, domain, origin, **kwargs): class TestTernaryOp(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -496,7 +496,7 @@ def validation(infield, outfield, *, domain, origin, **kwargs): class TestThreeWayAnd(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -518,7 +518,7 @@ def validation(outfield, *, a, b, c, domain, origin, **kwargs): class TestThreeWayOr(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -540,7 +540,7 @@ def validation(outfield, *, a, b, c, domain, origin, **kwargs): class TestOptionalField(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 32), (1, 32), (1, 32)] backends = ALL_BACKENDS symbols = dict( @@ -568,7 +568,7 @@ class TestNotSpecifiedOptionalField(TestOptionalField): class TestTwoOptionalFields(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 32), (1, 32), (1, 32)] backends = ALL_BACKENDS symbols = dict( diff --git a/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py b/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py index 963b824122..8efc414458 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py @@ -36,7 +36,7 @@ def sample_builder(): @pytest.fixture def sample_args_data(): - dtype = np.dtype(np.float_) + dtype = np.dtype(np.float64) yield ModuleData( field_info={ "in_field": FieldInfo( diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index e62f878746..1f7a779835 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -720,7 +720,7 @@ def definition_func(field: gtscript.Field[float]): class TestRegions: def test_one_interval_only(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...), horizontal(region[I[0:3], :]): in_f = 1.0 @@ -732,7 +732,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert isinstance(def_ir.computations[0].body.stmts[0], nodes.HorizontalIf) def test_one_interval_only_single(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...), horizontal(region[I[0], :]): in_f = 1.0 @@ -744,7 +744,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert def_ir.computations[0].body.stmts[0].intervals["I"].is_single_index def test_from_external(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import i1 with computation(PARALLEL), interval(...), horizontal(region[i1, :]): @@ -766,7 +766,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert def_ir.computations[0].body.stmts[0].intervals["I"].is_single_index def test_multiple_inline(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_f = in_f + 1.0 with horizontal(region[I[0], :], region[:, J[-1]]): @@ -789,7 +789,7 @@ def region_func(): return field - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_f = region_func() @@ -801,7 +801,7 @@ def stencil(in_f: gtscript.Field[np.float_]): ) def test_error_undefined(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import i0 # forget to add 'ia' with computation(PARALLEL), interval(...): @@ -813,7 +813,7 @@ def stencil(in_f: gtscript.Field[np.float_]): parse_definition(stencil, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_error_nested(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_f = in_f + 1.0 with horizontal(region[I[0], :]): @@ -1054,9 +1054,9 @@ def definition(inout_field: gtscript.Field[float]): class TestReducedDimensions: def test_syntax(self): def definition_func( - field_3d: gtscript.Field[gtscript.IJK, np.float_], - field_2d: gtscript.Field[gtscript.IJ, np.float_], - field_1d: gtscript.Field[gtscript.K, np.float_], + field_3d: gtscript.Field[gtscript.IJK, np.float64], + field_2d: gtscript.Field[gtscript.IJ, np.float64], + field_1d: gtscript.Field[gtscript.K, np.float64], ): with computation(FORWARD), interval(...): field_2d = field_1d[1] @@ -1085,8 +1085,8 @@ def definition_func( def test_error_syntax(self): def definition( - field_in: gtscript.Field[gtscript.K, np.float_], - field_out: gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.K, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): with computation(PARALLEL), interval(...): field_out = field_in[0, 0, 1] @@ -1099,8 +1099,8 @@ def definition( def test_error_write_1d(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: gtscript.Field[gtscript.K, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.K, np.float64], ): with computation(PARALLEL), interval(...): field_out = field_in[0, 0, 0] @@ -1113,10 +1113,10 @@ def definition( def test_higher_dim_temp(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): - tmp: Field[IJK, (np.float_, (2,))] = 0.0 + tmp: Field[IJK, (np.float64, (2,))] = 0.0 with computation(PARALLEL), interval(...): tmp[0, 0, 0][0] = field_in field_out = tmp[0, 0, 0][0] @@ -1125,10 +1125,10 @@ def definition( def test_typed_temp_missing(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): - tmp: Field[IJ, np.float_] = 0.0 + tmp: Field[IJ, np.float64] = 0.0 with computation(FORWARD), interval(1, None): tmp = field_in[0, 0, -1] field_out = tmp @@ -1143,9 +1143,9 @@ def definition( class TestDataDimensions: def test_syntax(self): def definition( - field_in: gtscript.Field[np.float_], - another_field: gtscript.Field[(np.float_, 3)], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + another_field: gtscript.Field[(np.float64, 3)], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][0] = field_in @@ -1156,8 +1156,8 @@ def definition( def test_syntax_no_datadim(self): def definition( - field_in: gtscript.Field[np.float_], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][0] = field_in @@ -1169,8 +1169,8 @@ def definition( def test_syntax_out_bounds(self): def definition( - field_in: gtscript.Field[np.float_], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][3] = field_in[0, 0, 0] @@ -1180,8 +1180,8 @@ def definition( def test_indirect_access_read(self): def definition( - field_3d: gtscript.Field[np.float_], - field_4d: gtscript.Field[gtscript.IJK, (np.float_, (2,))], + field_3d: gtscript.Field[np.float64], + field_4d: gtscript.Field[gtscript.IJK, (np.float64, (2,))], variable: float, ): with computation(PARALLEL), interval(...): @@ -1194,8 +1194,8 @@ def definition( def test_indirect_access_write(self): def definition( - field_3d: gtscript.Field[np.float_], - field_4d: gtscript.Field[gtscript.IJK, (np.float_, (2,))], + field_3d: gtscript.Field[np.float64], + field_4d: gtscript.Field[gtscript.IJK, (np.float64, (2,))], variable: float, ): with computation(PARALLEL), interval(...): @@ -1381,14 +1381,14 @@ def test_literal_floating_parametrization(self, the_float): class TestAssignmentSyntax: def test_ellipsis(self): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): out_field[...] = in_field parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_offset(self): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = in_field @@ -1397,15 +1397,15 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field[0, 0, 1] = in_field parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import offset with computation(PARALLEL), interval(...): @@ -1471,8 +1471,8 @@ def test_slice(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field[:, :, :] = in_field @@ -1483,8 +1483,8 @@ def test_string(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field["a_key"] = in_field @@ -1492,7 +1492,7 @@ def func( parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_augmented(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += 2.0 in_field -= 0.5 @@ -1589,7 +1589,7 @@ def data_dims_with_at( class TestNestedWithSyntax: def test_nested_with(self): - def definition(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def definition(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL): with interval(...): in_field = out_field @@ -1598,7 +1598,7 @@ def definition(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np def test_nested_with_ordering(self): def definition_fw( - in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_] + in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64] ): from gt4py.cartesian.__gtscript__ import FORWARD, computation, interval @@ -1609,7 +1609,7 @@ def definition_fw( in_field = out_field + 2 def definition_bw( - in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_] + in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64] ): from gt4py.cartesian.__gtscript__ import FORWARD, computation, interval @@ -1633,35 +1633,35 @@ def definition_bw( class TestNativeFunctions: def test_simple_call(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_offset_arg(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(in_field[1, 0, 0]) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_nested_calls(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(abs(in_field)) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_nested_external_call(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(add_external_const(in_field)) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_multi_nested_calls(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += min(abs(sin(add_external_const(in_field))), -0.5) @@ -1672,28 +1672,28 @@ def test_native_in_function(self): def sinus(field_in): return sin(field_in) - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sinus(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_unary(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = not isfinite(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_binary(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = asin(in_field) + 1 parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_ternary(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = asin(in_field) + 1 if 1 < in_field else sin(in_field) @@ -1702,7 +1702,7 @@ def func(in_field: gtscript.Field[np.float_]): class TestWarnInlined: def test_inlined_emits_warning(self): - def func(field: gtscript.Field[np.float_]): + def func(field: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import SET_TO_ONE with computation(PARALLEL), interval(...): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f02fdf4cc4..d878d8d3ff 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -752,7 +752,16 @@ def expected(a, b, c, d): matrices[:, :, i[1:], i[:-1]] = a[:, :, 1:] matrices[:, :, i, i] = b matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] - return np.linalg.solve(matrices, d) + # Changed in NumPY version 2.0: In a linear matrix equation ax = b, the b array + # is only treated as a shape (M,) column vector if it is exactly 1-dimensional. + # In all other instances it is treated as a stack of (M, K) matrices. Therefore + # below we add an extra dimension (K) of size 1. Previously b would be treated + # as a stack of (M,) vectors if b.ndim was equal to a.ndim - 1. + # Refer to https://numpy.org/doc/2.0/reference/generated/numpy.linalg.solve.html + d_ext = np.empty(shape=(*shape, 1)) + d_ext[:, :, :, 0] = d + x = np.linalg.solve(matrices, d_ext) + return x[:, :, :, 0] cases.verify_with_default_data(cartesian_case, solve_tridiag, ref=expected) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 30ceaf9376..e98e820f14 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -62,7 +62,13 @@ def tridiag_reference(): a = rng.normal(size=shape) b = rng.normal(size=shape) * 2 c = rng.normal(size=shape) - d = rng.normal(size=shape) + # Changed in NumPY version 2.0: In a linear matrix equation ax = b, the b array + # is only treated as a shape (M,) column vector if it is exactly 1-dimensional. + # In all other instances it is treated as a stack of (M, K) matrices. Therefore + # below we add an extra dimension (K) of size 1. Previously b would be treated + # as a stack of (M,) vectors if b.ndim was equal to a.ndim - 1. + # Refer to https://numpy.org/doc/2.0/reference/generated/numpy.linalg.solve.html + d = rng.normal(size=(*shape, 1)) matrices = np.zeros(shape + shape[-1:]) i = np.arange(shape[2]) @@ -70,7 +76,7 @@ def tridiag_reference(): matrices[:, :, i, i] = b matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] x = np.linalg.solve(matrices, d) - return a, b, c, d, x + return a, b, c, d[:, :, :, 0], x[:, :, :, 0] @fendef diff --git a/uv.lock b/uv.lock index 60f62028bd..c07a329b39 100644 --- a/uv.lock +++ b/uv.lock @@ -1,8 +1,8 @@ version = 1 requires-python = ">=3.10, <3.12" resolution-markers = [ - "python_full_version < '3.11'", "python_full_version >= '3.11'", + "python_full_version < '3.11'", ] conflicts = [[ { package = "gt4py", extra = "cuda11" }, @@ -44,6 +44,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, ] +[[package]] +name = "anyio" +version = "4.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, +] + [[package]] name = "apeye" version = "1.4.1" @@ -152,23 +167,24 @@ wheels = [ [[package]] name = "babel" -version = "2.16.0" +version = "2.17.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2a/74/f1bc80f23eeba13393b7222b11d95ca3af2c1e28edca18af487137eefed9/babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316", size = 9348104 } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/20/bc79bc575ba2e2a7f70e8a1155618bb1301eaa5132a8271373a6903f73f8/babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b", size = 9587599 }, + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, ] [[package]] name = "beautifulsoup4" -version = "4.12.3" +version = "4.13.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "soupsieve" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b3/ca/824b1195773ce6166d388573fc106ce56d4a805bd7427b624e063596ec58/beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051", size = 581181 } +sdist = { url = "https://files.pythonhosted.org/packages/f0/3c/adaf39ce1fb4afdd21b611e3d530b183bb7759c9b673d60db0e347fd4439/beautifulsoup4-4.13.3.tar.gz", hash = "sha256:1bd32405dacc920b42b83ba01644747ed77456a65760e285fbc47633ceddaf8b", size = 619516 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed", size = 147925 }, + { url = "https://files.pythonhosted.org/packages/f9/49/6abb616eb3cbab6a7cca303dc02fdf3836de2e0b834bf966a7f5271a34d8/beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16", size = 186015 }, ] [[package]] @@ -199,11 +215,11 @@ wheels = [ [[package]] name = "boltons" -version = "24.1.0" +version = "25.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/84/76/dfc34232b3e88634025563f52a430be0838182647c063f99569086922554/boltons-24.1.0.tar.gz", hash = "sha256:4a49b7d57ee055b83a458c8682a2a6f199d263a8aa517098bda9bab813554b87", size = 240916 } +sdist = { url = "https://files.pythonhosted.org/packages/63/54/71a94d8e02da9a865587fb3fff100cb0fc7aa9f4d5ed9ed3a591216ddcc7/boltons-25.0.0.tar.gz", hash = "sha256:e110fbdc30b7b9868cb604e3f71d4722dd8f4dcb4a5ddd06028ba8f1ab0b5ace", size = 246294 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/96/e44606e60a0c005ac5f2a641960a93ca8f449ebdce7479f9bc4f10bead6d/boltons-24.1.0-py3-none-any.whl", hash = "sha256:a1776d47fdc387fb730fba1fe245f405ee184ee0be2fb447dd289773a84aed3b", size = 192196 }, + { url = "https://files.pythonhosted.org/packages/45/7f/0e961cf3908bc4c1c3e027de2794f867c6c89fb4916fc7dba295a0e80a2d/boltons-25.0.0-py3-none-any.whl", hash = "sha256:dc9fb38bf28985715497d1b54d00b62ea866eca3938938ea9043e254a3a6ca62", size = 194210 }, ] [[package]] @@ -217,10 +233,11 @@ wheels = [ [[package]] name = "bump-my-version" -version = "0.30.0" +version = "0.32.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, + { name = "httpx" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "questionary" }, @@ -229,9 +246,9 @@ dependencies = [ { name = "tomlkit" }, { name = "wcmatch" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e7/4f/57eda33958c5820b462c4c262bc18dc374dca6312bbb63f95606172200cb/bump_my_version-0.30.0.tar.gz", hash = "sha256:d53e784c73abc4bb5759e296f510bc71878e1df078eb525542ec9291b5ceb195", size = 1062228 } +sdist = { url = "https://files.pythonhosted.org/packages/e7/8b/72f0cd91ca6e296b71b05d39fcfbcf365eebaa5679a863ce7bb4d9d8aad7/bump_my_version-0.32.0.tar.gz", hash = "sha256:e8d964d13ba3ab6c090a872d0b5094ecf8df7ae8052b09288ace00fc6647df27", size = 1028515 } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/9b/965ad61f85cbde14694516b02dcd38ec0c5cf7132fe33a30fddb4d8b0803/bump_my_version-0.30.0-py3-none-any.whl", hash = "sha256:b0d683a1cb97fbc2f46adf8eb39ff1f0bdd72866c3583fe01f9837d6f031e5e3", size = 55257 }, + { url = "https://files.pythonhosted.org/packages/ab/67/92853455bb91f09cb1bb9d3a4993b2e5fda80d6c44c727eb93993dc1cc60/bump_my_version-0.32.0-py3-none-any.whl", hash = "sha256:7c807110bdd8ecc845019e68a050ff378d836effb116440ba7f4a8ad59652b63", size = 57572 }, ] [[package]] @@ -277,11 +294,11 @@ wheels = [ [[package]] name = "certifi" -version = "2024.12.14" +version = "2025.1.31" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/bd/1d41ee578ce09523c81a15426705dd20969f5abf006d1afe8aeff0dd776a/certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db", size = 166010 } +sdist = { url = "https://files.pythonhosted.org/packages/1c/ab/c9f1e32b7b1bf505bf26f0ef697775960db7932abeb7b516de930ba2705f/certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651", size = 167577 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, + { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 }, ] [[package]] @@ -461,7 +478,8 @@ name = "contourpy" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/25/c2/fc7193cc5383637ff390a712e88e4ded0452c9fbcf84abe3de5ea3df1866/contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699", size = 13465753 } wheels = [ @@ -542,7 +560,8 @@ version = "13.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastrlock" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/2a/1b/3afbaea2b78114c82b33ecc9affc79b7d9f4899945940b9b50790c93fd33/cupy_cuda11x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ef854f0c63525d8163ab7af19f503d964de9dde0dd1cf9ea806a6ecb302cdce3", size = 109578634 }, @@ -559,7 +578,8 @@ version = "13.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastrlock" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/34/60/dc268d1d9c5fdde4673a463feff5e9c70c59f477e647b54b501f65deef60/cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:674488e990998042cc54d2486d3c37cae80a12ba3787636be5a10b9446dd6914", size = 103601326 }, @@ -576,7 +596,8 @@ version = "13.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastrlock" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/f8/16/7fd4bc8a8f1a4697f76e52c13f348f284fcc5c37195efd7e4c5d0eb2b15c/cupy_rocm_4_3-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:fc6b93be093bcea8b820baed856b61efc5c8cb09b02ebdc890431655714366ad", size = 41259087 }, @@ -589,7 +610,8 @@ version = "13.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fastrlock" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-cuda11' and extra != 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/8d/2e/6e4ecd65f5158808a54ef75d90fc7a884afb55bd405c4a7dbc34bb4a8f96/cupy_rocm_5_0-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d4c370441f7778b00f3ab80d6f0d669ea0215b6e96bbed9663ecce7ffce83fa9", size = 60056031 }, @@ -677,10 +699,10 @@ wheels = [ [[package]] name = "dace" version = "1.0.0" -source = { git = "https://github.com/spcl/dace?branch=main#118c1312961dc1146f43d5b15cde4b97e067d9cb" } +source = { git = "https://github.com/spcl/dace?branch=main#5097d6f1a4b6e1dc8e06be6eb4aa585a6c6e04f3" } resolution-markers = [ - "python_full_version < '3.11'", "python_full_version >= '3.11'", + "python_full_version < '3.11'", ] dependencies = [ { name = "aenum" }, @@ -688,7 +710,7 @@ dependencies = [ { name = "dill" }, { name = "fparser" }, { name = "networkx" }, - { name = "numpy" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" } }, { name = "packaging" }, { name = "ply" }, { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, @@ -701,8 +723,8 @@ name = "dace" version = "1.0.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.11'", "python_full_version >= '3.11'", + "python_full_version < '3.11'", ] dependencies = [ { name = "aenum" }, @@ -710,7 +732,7 @@ dependencies = [ { name = "dill" }, { name = "fparser" }, { name = "networkx" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" } }, { name = "packaging" }, { name = "ply" }, { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, @@ -747,14 +769,14 @@ wheels = [ [[package]] name = "deepdiff" -version = "8.1.1" +version = "8.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "orderly-set" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/4b/ce2d3a36f77186d7dbca0f10b33e6a1c0eee390d9434960d2a14e2736b52/deepdiff-8.1.1.tar.gz", hash = "sha256:dd7bc7d5c8b51b5b90f01b0e2fe23c801fd8b4c6a7ee7e31c5a3c3663fcc7ceb", size = 433560 } +sdist = { url = "https://files.pythonhosted.org/packages/89/12/207d2ec96a526cf9d04fc2423ff9832e93b665e94b9d7c9b5198903e18a7/deepdiff-8.2.0.tar.gz", hash = "sha256:6ec78f65031485735545ffbe7a61e716c3c2d12ca6416886d5e9291fc76c46c3", size = 432573 } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/f7/2df72b55635926872b947203aacbe7e1109a51929aec8ebfef8c4a348eb5/deepdiff-8.1.1-py3-none-any.whl", hash = "sha256:b0231fa3afb0f7184e82535f2b4a36636442ed21e94a0cf3aaa7982157e7ebca", size = 84655 }, + { url = "https://files.pythonhosted.org/packages/6c/13/d7dd6b8c297b1d5cfea4f1ebd678e68d90ab04b6613d005c0a7c506d11e1/deepdiff-8.2.0-py3-none-any.whl", hash = "sha256:5091f2cdfd372b1b9f6bfd8065ba323ae31118dc4e42594371b38c8bea3fd0a4", size = 83672 }, ] [[package]] @@ -889,15 +911,15 @@ wheels = [ [[package]] name = "faker" -version = "35.0.0" +version = "35.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "python-dateutil" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d5/18/86fe668976308d09e0178041c3756e646a1f5ddc676aa7fb0cf3cd52f5b9/faker-35.0.0.tar.gz", hash = "sha256:42f2da8cf561e38c72b25e9891168b1e25fec42b6b0b5b0b6cd6041da54af885", size = 1855098 } +sdist = { url = "https://files.pythonhosted.org/packages/6c/d9/c5bc5edaeea1a3a5da6e7f93a5c0bdd49e0740d8c4a1e7ea9515fd4da2ed/faker-35.2.0.tar.gz", hash = "sha256:28c24061780f83b45d9cb15a72b8f143b09d276c9ff52eb557744b7a89e8ba19", size = 1874908 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/fe/40452fb1730b10afa34dfe016097b28baa070ad74a1c1a3512ebed438c08/Faker-35.0.0-py3-none-any.whl", hash = "sha256:926d2301787220e0554c2e39afc4dc535ce4b0a8d0a089657137999f66334ef4", size = 1894841 }, + { url = "https://files.pythonhosted.org/packages/4e/db/bab82efcf241dabc93ad65cebaf0f2332cb2827b55a5d3a6ef1d52fa2c29/Faker-35.2.0-py3-none-any.whl", hash = "sha256:609abe555761ff31b0e5e16f958696e9b65c9224a7ac612ac96bfc2b8f09fe35", size = 1917786 }, ] [[package]] @@ -944,27 +966,27 @@ wheels = [ [[package]] name = "fonttools" -version = "4.55.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/55/55/3b1566c6186a5e58a17a19ad63195f87c6ca4039ef10ff5318a1b9fc5639/fonttools-4.55.7.tar.gz", hash = "sha256:6899e3d97225a8218f525e9754da0376e1c62953a0d57a76c5abaada51e0d140", size = 3458372 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/5c/ce2fce845af9696d043ac912f15b9fac4b9002fcd9ff66b80aa513a6c43f/fonttools-4.55.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c2680a3e6e2e2d104a7ea81fb89323e1a9122c23b03d6569d0768887d0d76e69", size = 2752048 }, - { url = "https://files.pythonhosted.org/packages/07/9b/f7f9409adcf22763263c6327d2d31d538babd9ad2d63d1732c9e85d60a78/fonttools-4.55.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a7831d16c95b60866772a15fdcc03772625c4bb6d858e0ad8ef3d6e48709b2ef", size = 2280495 }, - { url = "https://files.pythonhosted.org/packages/91/df/348cf4ff1becd63ed952e35e436de3f9fd3245edb74c070457b465c40a58/fonttools-4.55.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:833927d089e6585019f2c85e3f8f7d87733e3fe81cd704ebaca7afa27e2e7113", size = 4561947 }, - { url = "https://files.pythonhosted.org/packages/14/fe/48b808bdf14bb9467e4a5aaa8aa89f8aba9979d52be3f7f1962f065e933e/fonttools-4.55.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7858dc6823296a053d85b831fa8428781c6c6f06fca44582bf7b6b2ff32a9089", size = 4604618 }, - { url = "https://files.pythonhosted.org/packages/52/25/305d88761aa15a8b2761869a15db34c070e72756d166a163756c53d07b35/fonttools-4.55.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:05568a66b090ed9d79aefdce2ceb180bb64fc856961deaedc29f5ad51355ce2c", size = 4558896 }, - { url = "https://files.pythonhosted.org/packages/0c/0b/c6f7877611940ab75dbe50f035d16ca5ce6d9ff2e5e65b9c76da830286ff/fonttools-4.55.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2dbc08e227fbeb716776905a7bd3c4fc62c8e37c8ef7d481acd10cb5fde12222", size = 4728347 }, - { url = "https://files.pythonhosted.org/packages/43/2c/490223b8cfaeccdef3d8819945a455aa8cc57f12f49233a3d40556b739cc/fonttools-4.55.7-cp310-cp310-win32.whl", hash = "sha256:6eb93cbba484a463b5ee83f7dd3211905f27a3871d20d90fb72de84c6c5056e3", size = 2155437 }, - { url = "https://files.pythonhosted.org/packages/37/f8/ee47526b3f03596cbed9dc7f38519cb650e7769bf9365e04bd81ff4a5302/fonttools-4.55.7-cp310-cp310-win_amd64.whl", hash = "sha256:7ff8e606f905048dc91a55a06d994b68065bf35752ae199df54a9bf30013dcaa", size = 2199898 }, - { url = "https://files.pythonhosted.org/packages/07/cb/f1dd2e31553bd03dcb4eb3af1ac6acc7fe41f26067d1bba104005ec1bb04/fonttools-4.55.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:916e1d926823b4b3b3815c59fc79f4ed670696fdd5fd9a5e690a0503eef38f79", size = 2753201 }, - { url = "https://files.pythonhosted.org/packages/21/84/f9f82093789947547b4bc86242669cde816ef4d949b23f472e47e85f125d/fonttools-4.55.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b89da448e0073408d7b2c44935f9fdae4fdc93644899f99f6102ef883ecf083c", size = 2281418 }, - { url = "https://files.pythonhosted.org/packages/46/e1/e0398d2aa7bf5400c84650fc7d85708502289bb92a40f8090e6e71cfe315/fonttools-4.55.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:087ace2d06894ccdb03e6975d05da6bb9cec0c689b2a9983c059880e33a1464a", size = 4869132 }, - { url = "https://files.pythonhosted.org/packages/d4/2d/9d86cd653c758334285a5c95d1bc0a7f13b6a72fc674c6b33fef3b8e3f77/fonttools-4.55.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775ed0700ee6f781436641f18a0c61b1846a8c1aecae6da6b395c4417e2cb567", size = 4898375 }, - { url = "https://files.pythonhosted.org/packages/48/ce/f49fccb7d9f7c9c6d239434fc48546a0b37a91ba8310c7bcd5127cfeb5f6/fonttools-4.55.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9ec71d0cc0242899f87e4c230ed0b22c7b8681f288fb80e3d81c2c54c5bd2c79", size = 4877574 }, - { url = "https://files.pythonhosted.org/packages/cc/85/afe73e96a1572ba0acc86e82d52554bf69f384b431acd7a15b8c3890833b/fonttools-4.55.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d4b1c5939c0521525f45522823508e6fad21175bca978583688ea3b3736e6625", size = 5045681 }, - { url = "https://files.pythonhosted.org/packages/b8/37/dc59bc5a2f049d39b62996c806c147ae2eee5316f047a37bcf4cb9dbc4ef/fonttools-4.55.7-cp311-cp311-win32.whl", hash = "sha256:23df0f1003abaf8a435543f59583fc247e7ae1b047ee2263510e0654a5f207e0", size = 2154302 }, - { url = "https://files.pythonhosted.org/packages/86/33/281989403a57945c7871df144af3512ad3d1cd223e025b08b7f377847e6d/fonttools-4.55.7-cp311-cp311-win_amd64.whl", hash = "sha256:82163d58b43eff6e2025a25c32905fdb9042a163cc1ff82dab393e7ffc77a7d5", size = 2200818 }, - { url = "https://files.pythonhosted.org/packages/7b/6d/304a16caf63a8c193ec387b1fae1cb10072a59d34549f2eefe7e3fa9f364/fonttools-4.55.7-py3-none-any.whl", hash = "sha256:3304dfcf9ca204dd0ef691a287bd851ddd8e8250108658c0677c3fdfec853a20", size = 1089677 }, +version = "4.55.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/24/de7e40adc99be2aa5adc6321bbdf3cf58dbe751b87343da658dd3fc7d946/fonttools-4.55.8.tar.gz", hash = "sha256:54d481d456dcd59af25d4a9c56b2c4c3f20e9620b261b84144e5950f33e8df17", size = 3458915 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/b8/82b3444cb081798eabb8397452ddf73680e623d7fdf9c575594a2240b8a2/fonttools-4.55.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d11600f5343092697d7434f3bf77a393c7ae74be206fe30e577b9a195fd53165", size = 2752288 }, + { url = "https://files.pythonhosted.org/packages/86/8f/9c5f2172e9f6dcf52bb6477bcd5a023d056114787c8184b683c34996f5a1/fonttools-4.55.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c96f2506ce1a0beeaa9595f9a8b7446477eb133f40c0e41fc078744c28149f80", size = 2280718 }, + { url = "https://files.pythonhosted.org/packages/c6/a6/b7cd7b54412bb7a27e282ee54459cae24524ad0eab6f81ead2a91d435287/fonttools-4.55.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b5f05ef72e846e9f49ccdd74b9da4309901a4248434c63c1ee9321adcb51d65", size = 4562177 }, + { url = "https://files.pythonhosted.org/packages/0e/16/eff3be24cecb9336639148c40507f949c193642d8369352af480597633fb/fonttools-4.55.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba45b637da80a262b55b7657aec68da2ac54b8ae7891cd977a5dbe5fd26db429", size = 4604843 }, + { url = "https://files.pythonhosted.org/packages/b5/95/737574364439cbcc5e6d4f3e000f15432141680ca8cb5c216b619a3d1cab/fonttools-4.55.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:edcffaeadba9a334c1c3866e275d7dd495465e7dbd296f688901bdbd71758113", size = 4559127 }, + { url = "https://files.pythonhosted.org/packages/5f/07/ea90834742f9b3e51a05f0f15f7c817eb7aab3d6ebf4f06c4626825ccb89/fonttools-4.55.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b9f9fce3c9b2196e162182ec5db8af8eb3acd0d76c2eafe9fdba5f370044e556", size = 4728575 }, + { url = "https://files.pythonhosted.org/packages/93/74/0c816d83cd2945a25aed592b0cb3c9ba32e8b259781bf41dc112204129d9/fonttools-4.55.8-cp310-cp310-win32.whl", hash = "sha256:f089e8da0990cfe2d67e81d9cf581ff372b48dc5acf2782701844211cd1f0eb3", size = 2155662 }, + { url = "https://files.pythonhosted.org/packages/78/bc/f5a24229edd8cdd7494f2099e1c62fca288dad4c8637ee62df04459db27e/fonttools-4.55.8-cp310-cp310-win_amd64.whl", hash = "sha256:01ea3901b0802fc5f9e854f5aeb5bc27770dd9dd24c28df8f74ba90f8b3f5915", size = 2200126 }, + { url = "https://files.pythonhosted.org/packages/0a/e3/834e0919b34b40a6a2895f533323231bba3b8f5ae22c19ab725b84cf84c0/fonttools-4.55.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:95f5a1d4432b3cea6571f5ce4f4e9b25bf36efbd61c32f4f90130a690925d6ee", size = 2753424 }, + { url = "https://files.pythonhosted.org/packages/b6/f9/9cf7fc04da85d37cfa1c287f0a25c274d6940dad259dbaa9fd796b87bd3c/fonttools-4.55.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d20f152de7625a0008ba1513f126daaaa0de3b4b9030aa72dd5c27294992260", size = 2281635 }, + { url = "https://files.pythonhosted.org/packages/35/1f/25330293a5bb6bd50825725270c587c2b25c2694020a82d2c424d2fd5469/fonttools-4.55.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a3ff5bb95fd5a3962b2754f8435e6d930c84fc9e9921c51e802dddf40acd56", size = 4869363 }, + { url = "https://files.pythonhosted.org/packages/f2/e0/e58b10ef50830145ba94dbeb64b70773af61cfccea663d485c7fae2aab65/fonttools-4.55.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b99d4fd2b6d0a00c7336c8363fccc7a11eccef4b17393af75ca6e77cf93ff413", size = 4898604 }, + { url = "https://files.pythonhosted.org/packages/e0/66/b59025011dbae1ea10dcb60f713a10e54d17cde5c8dc48db75af79dc2088/fonttools-4.55.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d637e4d33e46619c79d1a6c725f74d71b574cd15fb5bbb9b6f3eba8f28363573", size = 4877804 }, + { url = "https://files.pythonhosted.org/packages/67/76/abbbae972af55d54f83fcaeb90e26aaac937c8711b5a32d7c63768c37891/fonttools-4.55.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0f38bfb6b7a39c4162c3eb0820a0bdf8e3bdd125cd54e10ba242397d15e32439", size = 5045913 }, + { url = "https://files.pythonhosted.org/packages/8b/f2/5eb68b5202731b008ccfd4ad6d82af9a8abdec411609e76fdd6c43881f2c/fonttools-4.55.8-cp311-cp311-win32.whl", hash = "sha256:acfec948de41cd5e640d5c15d0200e8b8e7c5c6bb82afe1ca095cbc4af1188ee", size = 2154525 }, + { url = "https://files.pythonhosted.org/packages/42/d6/96dc2462006ffa16c8d475244e372abdc47d03a7bd38be0f29e7ae552af4/fonttools-4.55.8-cp311-cp311-win_amd64.whl", hash = "sha256:604c805b41241b4880e2dc86cf2d4754c06777371c8299799ac88d836cb18c3b", size = 2201043 }, + { url = "https://files.pythonhosted.org/packages/cc/e6/efdcd5d6858b951c29d56de31a19355579d826712bf390d964a21b076ddb/fonttools-4.55.8-py3-none-any.whl", hash = "sha256:07636dae94f7fe88561f9da7a46b13d8e3f529f87fdb221b11d85f91eabceeb7", size = 1089900 }, ] [[package]] @@ -1051,7 +1073,8 @@ dependencies = [ { name = "mako" }, { name = "nanobind" }, { name = "ninja" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "packaging" }, { name = "pybind11" }, { name = "setuptools" }, @@ -1080,7 +1103,7 @@ dace = [ { name = "dace", version = "1.0.1", source = { registry = "https://pypi.org/simple" } }, ] dace-next = [ - { name = "dace", version = "1.0.0", source = { git = "https://github.com/spcl/dace?branch=main#118c1312961dc1146f43d5b15cde4b97e067d9cb" } }, + { name = "dace", version = "1.0.0", source = { git = "https://github.com/spcl/dace?branch=main#5097d6f1a4b6e1dc8e06be6eb4aa585a6c6e04f3" } }, ] formatting = [ { name = "clang-format" }, @@ -1317,6 +1340,15 @@ typing = [ { name = "types-tabulate", specifier = ">=0.8.10" }, ] +[[package]] +name = "h11" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, +] + [[package]] name = "html5lib" version = "1.1" @@ -1330,18 +1362,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/dd/a834df6482147d48e225a49515aabc28974ad5a4ca3215c18a882565b028/html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d", size = 112173 }, ] +[[package]] +name = "httpcore" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, +] + [[package]] name = "hypothesis" -version = "6.124.7" +version = "6.125.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/ef/6e3736663ee67369f7f5b697674bfbd3efc91e7096ddd4452bbbc80065ff/hypothesis-6.124.7.tar.gz", hash = "sha256:8ed6c6ae47e7d26d869c1dc3dee04e8fc50c95240715bb9915ded88d6d920f0e", size = 416938 } +sdist = { url = "https://files.pythonhosted.org/packages/f9/69/3273c85add01293b0ed8fc71554cecb256c9e7826fa102c72cc847bb8bac/hypothesis-6.125.2.tar.gz", hash = "sha256:c70f0a12deb688ce90f2765a507070c4bff57e48ac86849f4350bbddc1df41a3", size = 417961 } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/48/2412d4aacf1c50882126910ce036c92a838784915e3de66fb603a75c05ec/hypothesis-6.124.7-py3-none-any.whl", hash = "sha256:a6e1f66de84de3152d57f595a187a123ce3ecdea9dc8ef51ff8dcaa069137085", size = 479518 }, + { url = "https://files.pythonhosted.org/packages/3c/1b/e78605ce304554451a36c6e24e603cfcee808c9ed09be5112bf00a10eb5e/hypothesis-6.125.2-py3-none-any.whl", hash = "sha256:55d4966d521b85d2f77e916dabb00d66d5530ea9fbb89c7489ee810625fac802", size = 480692 }, ] [[package]] @@ -1415,7 +1475,7 @@ wheels = [ [[package]] name = "ipython" -version = "8.31.0" +version = "8.32.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, @@ -1430,9 +1490,9 @@ dependencies = [ { name = "traitlets" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/35/6f90fdddff7a08b7b715fccbd2427b5212c9525cd043d26fdc45bee0708d/ipython-8.31.0.tar.gz", hash = "sha256:b6a2274606bec6166405ff05e54932ed6e5cfecaca1fc05f2cacde7bb074d70b", size = 5501011 } +sdist = { url = "https://files.pythonhosted.org/packages/36/80/4d2a072e0db7d250f134bc11676517299264ebe16d62a8619d49a78ced73/ipython-8.32.0.tar.gz", hash = "sha256:be2c91895b0b9ea7ba49d33b23e2040c352b33eb6a519cca7ce6e0c743444251", size = 5507441 } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/60/d0feb6b6d9fe4ab89fe8fe5b47cbf6cd936bfd9f1e7ffa9d0015425aeed6/ipython-8.31.0-py3-none-any.whl", hash = "sha256:46ec58f8d3d076a61d128fe517a51eb730e3aaf0c184ea8c17d16e366660c6a6", size = 821583 }, + { url = "https://files.pythonhosted.org/packages/e7/e1/f4474a7ecdb7745a820f6f6039dc43c66add40f1bcc66485607d93571af6/ipython-8.32.0-py3-none-any.whl", hash = "sha256:cae85b0c61eff1fc48b0a8002de5958b6528fa9c8defb1894da63f42613708aa", size = 825524 }, ] [[package]] @@ -1442,7 +1502,8 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, { name = "ml-dtypes" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "opt-einsum" }, { name = "scipy" }, ] @@ -1486,7 +1547,8 @@ version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "scipy" }, ] wheels = [ @@ -1666,14 +1728,14 @@ wheels = [ [[package]] name = "mako" -version = "1.3.8" +version = "1.3.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5f/d9/8518279534ed7dace1795d5a47e49d5299dd0994eed1053996402a8902f9/mako-1.3.8.tar.gz", hash = "sha256:577b97e414580d3e088d47c2dbbe9594aa7a5146ed2875d4dfa9075af2dd3cc8", size = 392069 } +sdist = { url = "https://files.pythonhosted.org/packages/62/4f/ddb1965901bc388958db9f0c991255b2c469349a741ae8c9cd8a562d70a6/mako-1.3.9.tar.gz", hash = "sha256:b5d65ff3462870feec922dbccf38f6efb44e5714d7b593a656be86663d8600ac", size = 392195 } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/bf/7a6a36ce2e4cafdfb202752be68850e22607fccd692847c45c1ae3c17ba6/Mako-1.3.8-py3-none-any.whl", hash = "sha256:42f48953c7eb91332040ff567eb7eea69b22e7a4affbc5ba8e845e8f730f6627", size = 78569 }, + { url = "https://files.pythonhosted.org/packages/cd/83/de0a49e7de540513f53ab5d2e105321dedeb08a8f5850f0208decf4390ec/Mako-1.3.9-py3-none-any.whl", hash = "sha256:95920acccb578427a9aa38e37a186b1e43156c87260d7ba18ca63aa4c7cbd3a1", size = 78456 }, ] [[package]] @@ -1725,7 +1787,8 @@ dependencies = [ { name = "cycler" }, { name = "fonttools" }, { name = "kiwisolver" }, - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "packaging" }, { name = "pillow" }, { name = "pyparsing" }, @@ -1788,7 +1851,8 @@ name = "ml-dtypes" version = "0.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/32/49/6e67c334872d2c114df3020e579f3718c333198f8312290e09ec0216703a/ml_dtypes-0.5.1.tar.gz", hash = "sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9", size = 698772 } wheels = [ @@ -1909,11 +1973,11 @@ wheels = [ [[package]] name = "nanobind" -version = "2.4.0" +version = "2.5.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1e/01/a28722f6626e5c8a606dee71cb40c0b2ab9f7715b96bd34a9553c79dbf42/nanobind-2.4.0.tar.gz", hash = "sha256:a0392dee5f58881085b2ac8bfe8e53f74285aa4868b1472bfaf76cfb414e1c96", size = 953467 } +sdist = { url = "https://files.pythonhosted.org/packages/20/fa/8e5930837f9b08202c4e566cf529480b0c3266e88f39723388baf8c69700/nanobind-2.5.0.tar.gz", hash = "sha256:cc8412e94acffa20a369191382bcdbb6fbfb302e475e87cacff9516d51023a15", size = 962802 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/07/abff41fcade3613349eac71dacb166352babef515efd960a751e3175c262/nanobind-2.4.0-py3-none-any.whl", hash = "sha256:8cf27b04fbadeb9deb4a73f02bd838bf9f7e3e5a8ce44c50c93142b5728da58a", size = 232882 }, + { url = "https://files.pythonhosted.org/packages/8e/9e/dadc3831f40e22c1b3925f07894646ada7906ef5b48db5c5eb2b03ca9faa/nanobind-2.5.0-py3-none-any.whl", hash = "sha256:e1e5c816e5d10f0b252d82ba7f769f0f6679f5e043cf406aec3d9e184bf2a60d", size = 236912 }, ] [[package]] @@ -2042,6 +2106,10 @@ wheels = [ name = "numpy" version = "1.26.4" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } wheels = [ { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468 }, @@ -2062,6 +2130,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/6b/5610004206cf7f8e7ad91c5a85a8c71b2f2f8051a0c0c4d5916b76d6cbb2/numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", size = 15811913 }, ] +[[package]] +name = "numpy" +version = "2.2.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/d0/c12ddfd3a02274be06ffc71f3efc6d0e457b0409c4481596881e748cb264/numpy-2.2.2.tar.gz", hash = "sha256:ed6906f61834d687738d25988ae117683705636936cc605be0bb208b23df4d8f", size = 20233295 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/2a/69033dc22d981ad21325314f8357438078f5c28310a6d89fb3833030ec8a/numpy-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7079129b64cb78bdc8d611d1fd7e8002c0a2565da6a47c4df8062349fee90e3e", size = 21215825 }, + { url = "https://files.pythonhosted.org/packages/31/2c/39f91e00bbd3d5639b027ac48c55dc5f2992bd2b305412d26be4c830862a/numpy-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ec6c689c61df613b783aeb21f945c4cbe6c51c28cb70aae8430577ab39f163e", size = 14354996 }, + { url = "https://files.pythonhosted.org/packages/0a/2c/d468ebd253851af10de5b3e8f3418ebabfaab5f0337a75299fbeb8b8c17a/numpy-2.2.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:40c7ff5da22cd391944a28c6a9c638a5eef77fcf71d6e3a79e1d9d9e82752715", size = 5393621 }, + { url = "https://files.pythonhosted.org/packages/7f/f4/3d8a5a0da297034106c5de92be881aca7079cde6058934215a1de91334f6/numpy-2.2.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:995f9e8181723852ca458e22de5d9b7d3ba4da3f11cc1cb113f093b271d7965a", size = 6928931 }, + { url = "https://files.pythonhosted.org/packages/47/a7/029354ab56edd43dd3f5efbfad292b8844f98b93174f322f82353fa46efa/numpy-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b78ea78450fd96a498f50ee096f69c75379af5138f7881a51355ab0e11286c97", size = 14333157 }, + { url = "https://files.pythonhosted.org/packages/e3/d7/11fc594838d35c43519763310c316d4fd56f8600d3fc80a8e13e325b5c5c/numpy-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fbe72d347fbc59f94124125e73fc4976a06927ebc503ec5afbfb35f193cd957", size = 16381794 }, + { url = "https://files.pythonhosted.org/packages/af/d4/dd9b19cd4aff9c79d3f54d17f8be815407520d3116004bc574948336981b/numpy-2.2.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8e6da5cffbbe571f93588f562ed130ea63ee206d12851b60819512dd3e1ba50d", size = 15543990 }, + { url = "https://files.pythonhosted.org/packages/30/97/ab96b7650f27f684a9b1e46757a7294ecc50cab27701d05f146e9f779627/numpy-2.2.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:09d6a2032faf25e8d0cadde7fd6145118ac55d2740132c1d845f98721b5ebcfd", size = 18170896 }, + { url = "https://files.pythonhosted.org/packages/81/9b/bae9618cab20db67a2ca9d711795cad29b2ca4b73034dd3b5d05b962070a/numpy-2.2.2-cp310-cp310-win32.whl", hash = "sha256:159ff6ee4c4a36a23fe01b7c3d07bd8c14cc433d9720f977fcd52c13c0098160", size = 6573458 }, + { url = "https://files.pythonhosted.org/packages/92/9b/95678092febd14070cfb7906ea7932e71e9dd5a6ab3ee948f9ed975e905d/numpy-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:64bd6e1762cd7f0986a740fee4dff927b9ec2c5e4d9a28d056eb17d332158014", size = 12915812 }, + { url = "https://files.pythonhosted.org/packages/21/67/32c68756eed84df181c06528ff57e09138f893c4653448c4967311e0f992/numpy-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:642199e98af1bd2b6aeb8ecf726972d238c9877b0f6e8221ee5ab945ec8a2189", size = 21220002 }, + { url = "https://files.pythonhosted.org/packages/3b/89/f43bcad18f2b2e5814457b1c7f7b0e671d0db12c8c0e43397ab8cb1831ed/numpy-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d9fc9d812c81e6168b6d405bf00b8d6739a7f72ef22a9214c4241e0dc70b323", size = 14391215 }, + { url = "https://files.pythonhosted.org/packages/9c/e6/efb8cd6122bf25e86e3dd89d9dbfec9e6861c50e8810eed77d4be59b51c6/numpy-2.2.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c7d1fd447e33ee20c1f33f2c8e6634211124a9aabde3c617687d8b739aa69eac", size = 5391918 }, + { url = "https://files.pythonhosted.org/packages/47/e2/fccf89d64d9b47ffb242823d4e851fc9d36fa751908c9aac2807924d9b4e/numpy-2.2.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:451e854cfae0febe723077bd0cf0a4302a5d84ff25f0bfece8f29206c7bed02e", size = 6933133 }, + { url = "https://files.pythonhosted.org/packages/34/22/5ece749c0e5420a9380eef6fbf83d16a50010bd18fef77b9193d80a6760e/numpy-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd249bc894af67cbd8bad2c22e7cbcd46cf87ddfca1f1289d1e7e54868cc785c", size = 14338187 }, + { url = "https://files.pythonhosted.org/packages/5b/86/caec78829311f62afa6fa334c8dfcd79cffb4d24bcf96ee02ae4840d462b/numpy-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02935e2c3c0c6cbe9c7955a8efa8908dd4221d7755644c59d1bba28b94fd334f", size = 16393429 }, + { url = "https://files.pythonhosted.org/packages/c8/4e/0c25f74c88239a37924577d6ad780f3212a50f4b4b5f54f5e8c918d726bd/numpy-2.2.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a972cec723e0563aa0823ee2ab1df0cb196ed0778f173b381c871a03719d4826", size = 15559103 }, + { url = "https://files.pythonhosted.org/packages/d4/bd/d557f10fa50dc4d5871fb9606af563249b66af2fc6f99041a10e8757c6f1/numpy-2.2.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6d6a0910c3b4368d89dde073e630882cdb266755565155bc33520283b2d9df8", size = 18182967 }, + { url = "https://files.pythonhosted.org/packages/30/e9/66cc0f66386d78ed89e45a56e2a1d051e177b6e04477c4a41cd590ef4017/numpy-2.2.2-cp311-cp311-win32.whl", hash = "sha256:860fd59990c37c3ef913c3ae390b3929d005243acca1a86facb0773e2d8d9e50", size = 6571499 }, + { url = "https://files.pythonhosted.org/packages/66/a3/4139296b481ae7304a43581046b8f0a20da6a0dfe0ee47a044cade796603/numpy-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:da1eeb460ecce8d5b8608826595c777728cdf28ce7b5a5a8c8ac8d949beadcf2", size = 12919805 }, + { url = "https://files.pythonhosted.org/packages/96/7e/1dd770ee68916ed358991ab62c2cc353ffd98d0b75b901d52183ca28e8bb/numpy-2.2.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b0531f0b0e07643eb089df4c509d30d72c9ef40defa53e41363eca8a8cc61495", size = 21047291 }, + { url = "https://files.pythonhosted.org/packages/d1/3c/ccd08578dc532a8e6927952339d4a02682b776d5e85be49ed0760308433e/numpy-2.2.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:e9e82dcb3f2ebbc8cb5ce1102d5f1c5ed236bf8a11730fb45ba82e2841ec21df", size = 6792494 }, + { url = "https://files.pythonhosted.org/packages/7c/28/8754b9aee4f97199f9a047f73bb644b5a2014994a6d7b061ba67134a42de/numpy-2.2.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0d4142eb40ca6f94539e4db929410f2a46052a0fe7a2c1c59f6179c39938d2a", size = 16197312 }, + { url = "https://files.pythonhosted.org/packages/26/96/deb93f871f401045a684ca08a009382b247d14996d7a94fea6aa43c67b94/numpy-2.2.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:356ca982c188acbfa6af0d694284d8cf20e95b1c3d0aefa8929376fea9146f60", size = 12822674 }, +] + [[package]] name = "opt-einsum" version = "3.4.0" @@ -2073,11 +2177,11 @@ wheels = [ [[package]] name = "orderly-set" -version = "5.2.3" +version = "5.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5d/9e/8fdcb9ab1b6983cc7c185a4ddafc27518118bd80e9ff2f30aba83636af37/orderly_set-5.2.3.tar.gz", hash = "sha256:571ed97c5a5fca7ddeb6b2d26c19aca896b0ed91f334d9c109edd2f265fb3017", size = 19698 } +sdist = { url = "https://files.pythonhosted.org/packages/e7/0e/ef328b512c2595831304e51f25e9287697b7bf13be0527ca9592a2659c16/orderly_set-5.3.0.tar.gz", hash = "sha256:80b3d8fdd3d39004d9aad389eaa0eab02c71f0a0511ba3a6d54a935a6c6a0acc", size = 20026 } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/bb/a3a4eab8430f14c7d1476f9db261d32654cb3d1794c0266a46f6574e1190/orderly_set-5.2.3-py3-none-any.whl", hash = "sha256:d357cedcf67f4ebff0d4cbd5b0997e98eeb65dd24fdf5c990a501ae9e82c7d34", size = 12024 }, + { url = "https://files.pythonhosted.org/packages/df/fe/8009ebb64a19cf4bdf51b16d3074375010735d8c30408efada6ce02bf37e/orderly_set-5.3.0-py3-none-any.whl", hash = "sha256:c2c0bfe604f5d3d9b24e8262a06feb612594f37aa3845650548befd7772945d1", size = 12179 }, ] [[package]] @@ -2607,42 +2711,42 @@ wheels = [ [[package]] name = "pyzmq" -version = "26.2.0" +version = "26.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/05/bed626b9f7bb2322cdbbf7b4bd8f54b1b617b0d2ab2d3547d6e39428a48e/pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f", size = 271975 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/a8/9837c39aba390eb7d01924ace49d761c8dbe7bc2d6082346d00c8332e431/pyzmq-26.2.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:ddf33d97d2f52d89f6e6e7ae66ee35a4d9ca6f36eda89c24591b0c40205a3629", size = 1340058 }, - { url = "https://files.pythonhosted.org/packages/a2/1f/a006f2e8e4f7d41d464272012695da17fb95f33b54342612a6890da96ff6/pyzmq-26.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dacd995031a01d16eec825bf30802fceb2c3791ef24bcce48fa98ce40918c27b", size = 1008818 }, - { url = "https://files.pythonhosted.org/packages/b6/09/b51b6683fde5ca04593a57bbe81788b6b43114d8f8ee4e80afc991e14760/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89289a5ee32ef6c439086184529ae060c741334b8970a6855ec0b6ad3ff28764", size = 673199 }, - { url = "https://files.pythonhosted.org/packages/c9/78/486f3e2e824f3a645238332bf5a4c4b4477c3063033a27c1e4052358dee2/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5506f06d7dc6ecf1efacb4a013b1f05071bb24b76350832c96449f4a2d95091c", size = 911762 }, - { url = "https://files.pythonhosted.org/packages/5e/3b/2eb1667c9b866f53e76ee8b0c301b0469745a23bd5a87b7ee3d5dd9eb6e5/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea039387c10202ce304af74def5021e9adc6297067f3441d348d2b633e8166a", size = 868773 }, - { url = "https://files.pythonhosted.org/packages/16/29/ca99b4598a9dc7e468b5417eda91f372b595be1e3eec9b7cbe8e5d3584e8/pyzmq-26.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2224fa4a4c2ee872886ed00a571f5e967c85e078e8e8c2530a2fb01b3309b88", size = 868834 }, - { url = "https://files.pythonhosted.org/packages/ad/e5/9efaeb1d2f4f8c50da04144f639b042bc52869d3a206d6bf672ab3522163/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:28ad5233e9c3b52d76196c696e362508959741e1a005fb8fa03b51aea156088f", size = 1202861 }, - { url = "https://files.pythonhosted.org/packages/c3/62/c721b5608a8ac0a69bb83cbb7d07a56f3ff00b3991a138e44198a16f94c7/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:1c17211bc037c7d88e85ed8b7d8f7e52db6dc8eca5590d162717c654550f7282", size = 1515304 }, - { url = "https://files.pythonhosted.org/packages/87/84/e8bd321aa99b72f48d4606fc5a0a920154125bd0a4608c67eab742dab087/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b8f86dd868d41bea9a5f873ee13bf5551c94cf6bc51baebc6f85075971fe6eea", size = 1414712 }, - { url = "https://files.pythonhosted.org/packages/cd/cd/420e3fd1ac6977b008b72e7ad2dae6350cc84d4c5027fc390b024e61738f/pyzmq-26.2.0-cp310-cp310-win32.whl", hash = "sha256:46a446c212e58456b23af260f3d9fb785054f3e3653dbf7279d8f2b5546b21c2", size = 578113 }, - { url = "https://files.pythonhosted.org/packages/5c/57/73930d56ed45ae0cb4946f383f985c855c9b3d4063f26416998f07523c0e/pyzmq-26.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:49d34ab71db5a9c292a7644ce74190b1dd5a3475612eefb1f8be1d6961441971", size = 641631 }, - { url = "https://files.pythonhosted.org/packages/61/d2/ae6ac5c397f1ccad59031c64beaafce7a0d6182e0452cc48f1c9c87d2dd0/pyzmq-26.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:bfa832bfa540e5b5c27dcf5de5d82ebc431b82c453a43d141afb1e5d2de025fa", size = 543528 }, - { url = "https://files.pythonhosted.org/packages/12/20/de7442172f77f7c96299a0ac70e7d4fb78cd51eca67aa2cf552b66c14196/pyzmq-26.2.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:8f7e66c7113c684c2b3f1c83cdd3376103ee0ce4c49ff80a648643e57fb22218", size = 1340639 }, - { url = "https://files.pythonhosted.org/packages/98/4d/5000468bd64c7910190ed0a6c76a1ca59a68189ec1f007c451dc181a22f4/pyzmq-26.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3a495b30fc91db2db25120df5847d9833af237546fd59170701acd816ccc01c4", size = 1008710 }, - { url = "https://files.pythonhosted.org/packages/e1/bf/c67fd638c2f9fbbab8090a3ee779370b97c82b84cc12d0c498b285d7b2c0/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77eb0968da535cba0470a5165468b2cac7772cfb569977cff92e240f57e31bef", size = 673129 }, - { url = "https://files.pythonhosted.org/packages/86/94/99085a3f492aa538161cbf27246e8886ff850e113e0c294a5b8245f13b52/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ace4f71f1900a548f48407fc9be59c6ba9d9aaf658c2eea6cf2779e72f9f317", size = 910107 }, - { url = "https://files.pythonhosted.org/packages/31/1d/346809e8a9b999646d03f21096428453465b1bca5cd5c64ecd048d9ecb01/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92a78853d7280bffb93df0a4a6a2498cba10ee793cc8076ef797ef2f74d107cf", size = 867960 }, - { url = "https://files.pythonhosted.org/packages/ab/68/6fb6ae5551846ad5beca295b7bca32bf0a7ce19f135cb30e55fa2314e6b6/pyzmq-26.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:689c5d781014956a4a6de61d74ba97b23547e431e9e7d64f27d4922ba96e9d6e", size = 869204 }, - { url = "https://files.pythonhosted.org/packages/0f/f9/18417771dee223ccf0f48e29adf8b4e25ba6d0e8285e33bcbce078070bc3/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aca98bc423eb7d153214b2df397c6421ba6373d3397b26c057af3c904452e37", size = 1203351 }, - { url = "https://files.pythonhosted.org/packages/e0/46/f13e67fe0d4f8a2315782cbad50493de6203ea0d744610faf4d5f5b16e90/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f3496d76b89d9429a656293744ceca4d2ac2a10ae59b84c1da9b5165f429ad3", size = 1514204 }, - { url = "https://files.pythonhosted.org/packages/50/11/ddcf7343b7b7a226e0fc7b68cbf5a5bb56291fac07f5c3023bb4c319ebb4/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5c2b3bfd4b9689919db068ac6c9911f3fcb231c39f7dd30e3138be94896d18e6", size = 1414339 }, - { url = "https://files.pythonhosted.org/packages/01/14/1c18d7d5b7be2708f513f37c61bfadfa62161c10624f8733f1c8451b3509/pyzmq-26.2.0-cp311-cp311-win32.whl", hash = "sha256:eac5174677da084abf378739dbf4ad245661635f1600edd1221f150b165343f4", size = 576928 }, - { url = "https://files.pythonhosted.org/packages/3b/1b/0a540edd75a41df14ec416a9a500b9fec66e554aac920d4c58fbd5756776/pyzmq-26.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:5a509df7d0a83a4b178d0f937ef14286659225ef4e8812e05580776c70e155d5", size = 642317 }, - { url = "https://files.pythonhosted.org/packages/98/77/1cbfec0358078a4c5add529d8a70892db1be900980cdb5dd0898b3d6ab9d/pyzmq-26.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:c0e6091b157d48cbe37bd67233318dbb53e1e6327d6fc3bb284afd585d141003", size = 543834 }, - { url = "https://files.pythonhosted.org/packages/53/fb/36b2b2548286e9444e52fcd198760af99fd89102b5be50f0660fcfe902df/pyzmq-26.2.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:706e794564bec25819d21a41c31d4df2d48e1cc4b061e8d345d7fb4dd3e94072", size = 906955 }, - { url = "https://files.pythonhosted.org/packages/77/8f/6ce54f8979a01656e894946db6299e2273fcee21c8e5fa57c6295ef11f57/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b435f2753621cd36e7c1762156815e21c985c72b19135dac43a7f4f31d28dd1", size = 565701 }, - { url = "https://files.pythonhosted.org/packages/ee/1c/bf8cd66730a866b16db8483286078892b7f6536f8c389fb46e4beba0a970/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160c7e0a5eb178011e72892f99f918c04a131f36056d10d9c1afb223fc952c2d", size = 794312 }, - { url = "https://files.pythonhosted.org/packages/71/43/91fa4ff25bbfdc914ab6bafa0f03241d69370ef31a761d16bb859f346582/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4a71d5d6e7b28a47a394c0471b7e77a0661e2d651e7ae91e0cab0a587859ca", size = 752775 }, - { url = "https://files.pythonhosted.org/packages/ec/d2/3b2ab40f455a256cb6672186bea95cd97b459ce4594050132d71e76f0d6f/pyzmq-26.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:90412f2db8c02a3864cbfc67db0e3dcdbda336acf1c469526d3e869394fe001c", size = 550762 }, +sdist = { url = "https://files.pythonhosted.org/packages/5a/e3/8d0382cb59feb111c252b54e8728257416a38ffcb2243c4e4775a3c990fe/pyzmq-26.2.1.tar.gz", hash = "sha256:17d72a74e5e9ff3829deb72897a175333d3ef5b5413948cae3cf7ebf0b02ecca", size = 278433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/3d/c2d9d46c033d1b51692ea49a22439f7f66d91d5c938e8b5c56ed7a2151c2/pyzmq-26.2.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:f39d1227e8256d19899d953e6e19ed2ccb689102e6d85e024da5acf410f301eb", size = 1345451 }, + { url = "https://files.pythonhosted.org/packages/0e/df/4754a8abcdeef280651f9bb51446c47659910940b392a66acff7c37f5cef/pyzmq-26.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a23948554c692df95daed595fdd3b76b420a4939d7a8a28d6d7dea9711878641", size = 942766 }, + { url = "https://files.pythonhosted.org/packages/74/da/e6053a3b13c912eded6c2cdeee22ff3a4c33820d17f9eb24c7b6e957ffe7/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95f5728b367a042df146cec4340d75359ec6237beebf4a8f5cf74657c65b9257", size = 678488 }, + { url = "https://files.pythonhosted.org/packages/9e/50/614934145244142401ca174ca81071777ab93aa88173973ba0154f491e09/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95f7b01b3f275504011cf4cf21c6b885c8d627ce0867a7e83af1382ebab7b3ff", size = 917115 }, + { url = "https://files.pythonhosted.org/packages/80/2b/ebeb7bc4fc8e9e61650b2e09581597355a4341d413fa9b2947d7a6558119/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a00370a2ef2159c310e662c7c0f2d030f437f35f478bb8b2f70abd07e26b24", size = 874162 }, + { url = "https://files.pythonhosted.org/packages/79/48/93210621c331ad16313dc2849801411fbae10d91d878853933f2a85df8e7/pyzmq-26.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:8531ed35dfd1dd2af95f5d02afd6545e8650eedbf8c3d244a554cf47d8924459", size = 874180 }, + { url = "https://files.pythonhosted.org/packages/f0/8b/40924b4d8e33bfdd54c1970fb50f327e39b90b902f897cf09b30b2e9ac48/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cdb69710e462a38e6039cf17259d328f86383a06c20482cc154327968712273c", size = 1208139 }, + { url = "https://files.pythonhosted.org/packages/c8/b2/82d6675fc89bd965eae13c45002c792d33f06824589844b03f8ea8fc6d86/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e7eeaef81530d0b74ad0d29eec9997f1c9230c2f27242b8d17e0ee67662c8f6e", size = 1520666 }, + { url = "https://files.pythonhosted.org/packages/9d/e2/5ff15f2d3f920dcc559d477bd9bb3faacd6d79fcf7c5448e585c78f84849/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:361edfa350e3be1f987e592e834594422338d7174364763b7d3de5b0995b16f3", size = 1420056 }, + { url = "https://files.pythonhosted.org/packages/40/a2/f9bbeccf7f75aa0d8963e224e5730abcefbf742e1f2ae9ea60fd9d6ff72b/pyzmq-26.2.1-cp310-cp310-win32.whl", hash = "sha256:637536c07d2fb6a354988b2dd1d00d02eb5dd443f4bbee021ba30881af1c28aa", size = 583874 }, + { url = "https://files.pythonhosted.org/packages/56/b1/44f513135843272f0e12f5aebf4af35839e2a88eb45411f2c8c010d8c856/pyzmq-26.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:45fad32448fd214fbe60030aa92f97e64a7140b624290834cc9b27b3a11f9473", size = 647367 }, + { url = "https://files.pythonhosted.org/packages/27/9c/1bef14a37b02d651a462811bbdb1390b61cd4a5b5e95cbd7cc2d60ef848c/pyzmq-26.2.1-cp310-cp310-win_arm64.whl", hash = "sha256:d9da0289d8201c8a29fd158aaa0dfe2f2e14a181fd45e2dc1fbf969a62c1d594", size = 561784 }, + { url = "https://files.pythonhosted.org/packages/b9/03/5ecc46a6ed5971299f5c03e016ca637802d8660e44392bea774fb7797405/pyzmq-26.2.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:c059883840e634a21c5b31d9b9a0e2b48f991b94d60a811092bc37992715146a", size = 1346032 }, + { url = "https://files.pythonhosted.org/packages/40/51/48fec8f990ee644f461ff14c8fe5caa341b0b9b3a0ad7544f8ef17d6f528/pyzmq-26.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed038a921df836d2f538e509a59cb638df3e70ca0fcd70d0bf389dfcdf784d2a", size = 943324 }, + { url = "https://files.pythonhosted.org/packages/c1/f4/f322b389727c687845e38470b48d7a43c18a83f26d4d5084603c6c3f79ca/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9027a7fcf690f1a3635dc9e55e38a0d6602dbbc0548935d08d46d2e7ec91f454", size = 678418 }, + { url = "https://files.pythonhosted.org/packages/a8/df/2834e3202533bd05032d83e02db7ac09fa1be853bbef59974f2b2e3a8557/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d75fcb00a1537f8b0c0bb05322bc7e35966148ffc3e0362f0369e44a4a1de99", size = 915466 }, + { url = "https://files.pythonhosted.org/packages/b5/e2/45c0f6e122b562cb8c6c45c0dcac1160a4e2207385ef9b13463e74f93031/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0019cc804ac667fb8c8eaecdb66e6d4a68acf2e155d5c7d6381a5645bd93ae4", size = 873347 }, + { url = "https://files.pythonhosted.org/packages/de/b9/3e0fbddf8b87454e914501d368171466a12550c70355b3844115947d68ea/pyzmq-26.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f19dae58b616ac56b96f2e2290f2d18730a898a171f447f491cc059b073ca1fa", size = 874545 }, + { url = "https://files.pythonhosted.org/packages/1f/1c/1ee41d6e10b2127263b1994bc53b9e74ece015b0d2c0a30e0afaf69b78b2/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f5eeeb82feec1fc5cbafa5ee9022e87ffdb3a8c48afa035b356fcd20fc7f533f", size = 1208630 }, + { url = "https://files.pythonhosted.org/packages/3d/a9/50228465c625851a06aeee97c74f253631f509213f979166e83796299c60/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:000760e374d6f9d1a3478a42ed0c98604de68c9e94507e5452951e598ebecfba", size = 1519568 }, + { url = "https://files.pythonhosted.org/packages/c6/f2/6360b619e69da78863c2108beb5196ae8b955fe1e161c0b886b95dc6b1ac/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:817fcd3344d2a0b28622722b98500ae9c8bfee0f825b8450932ff19c0b15bebd", size = 1419677 }, + { url = "https://files.pythonhosted.org/packages/da/d5/f179da989168f5dfd1be8103ef508ade1d38a8078dda4f10ebae3131a490/pyzmq-26.2.1-cp311-cp311-win32.whl", hash = "sha256:88812b3b257f80444a986b3596e5ea5c4d4ed4276d2b85c153a6fbc5ca457ae7", size = 582682 }, + { url = "https://files.pythonhosted.org/packages/60/50/e5b2e9de3ffab73ff92bee736216cf209381081fa6ab6ba96427777d98b1/pyzmq-26.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:ef29630fde6022471d287c15c0a2484aba188adbfb978702624ba7a54ddfa6c1", size = 648128 }, + { url = "https://files.pythonhosted.org/packages/d9/fe/7bb93476dd8405b0fc9cab1fd921a08bd22d5e3016aa6daea1a78d54129b/pyzmq-26.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:f32718ee37c07932cc336096dc7403525301fd626349b6eff8470fe0f996d8d7", size = 562465 }, + { url = "https://files.pythonhosted.org/packages/65/d1/e630a75cfb2534574a1258fda54d02f13cf80b576d4ce6d2aa478dc67829/pyzmq-26.2.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:380816d298aed32b1a97b4973a4865ef3be402a2e760204509b52b6de79d755d", size = 847743 }, + { url = "https://files.pythonhosted.org/packages/27/df/f94a711b4f6c4b41e227f9a938103f52acf4c2e949d91cbc682495a48155/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97cbb368fd0debdbeb6ba5966aa28e9a1ae3396c7386d15569a6ca4be4572b99", size = 570991 }, + { url = "https://files.pythonhosted.org/packages/bf/08/0c6f97fb3c9dbfa23382f0efaf8f9aa1396a08a3358974eaae3ee659ed5c/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abf7b5942c6b0dafcc2823ddd9154f419147e24f8df5b41ca8ea40a6db90615c", size = 799664 }, + { url = "https://files.pythonhosted.org/packages/05/14/f4d4fd8bb8988c667845734dd756e9ee65b9a17a010d5f288dfca14a572d/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fe6e28a8856aea808715f7a4fc11f682b9d29cac5d6262dd8fe4f98edc12d53", size = 758156 }, + { url = "https://files.pythonhosted.org/packages/e3/fe/72e7e166bda3885810bee7b23049133e142f7c80c295bae02c562caeea16/pyzmq-26.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bd8fdee945b877aa3bffc6a5a8816deb048dab0544f9df3731ecd0e54d8c84c9", size = 556563 }, ] [[package]] @@ -2800,27 +2904,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.9.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1e/7f/60fda2eec81f23f8aa7cbbfdf6ec2ca11eb11c273827933fb2541c2ce9d8/ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a", size = 3586740 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/77/4fb790596d5d52c87fd55b7160c557c400e90f6116a56d82d76e95d9374a/ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624", size = 11656815 }, - { url = "https://files.pythonhosted.org/packages/a2/a8/3338ecb97573eafe74505f28431df3842c1933c5f8eae615427c1de32858/ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c", size = 11594821 }, - { url = "https://files.pythonhosted.org/packages/8e/89/320223c3421962762531a6b2dd58579b858ca9916fb2674874df5e97d628/ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4", size = 11040475 }, - { url = "https://files.pythonhosted.org/packages/b2/bd/1d775eac5e51409535804a3a888a9623e87a8f4b53e2491580858a083692/ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439", size = 11856207 }, - { url = "https://files.pythonhosted.org/packages/7f/c6/3e14e09be29587393d188454064a4aa85174910d16644051a80444e4fd88/ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5", size = 11420460 }, - { url = "https://files.pythonhosted.org/packages/ef/42/b7ca38ffd568ae9b128a2fa76353e9a9a3c80ef19746408d4ce99217ecc1/ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4", size = 12605472 }, - { url = "https://files.pythonhosted.org/packages/a6/a1/3167023f23e3530fde899497ccfe239e4523854cb874458ac082992d206c/ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1", size = 13243123 }, - { url = "https://files.pythonhosted.org/packages/d0/b4/3c600758e320f5bf7de16858502e849f4216cb0151f819fa0d1154874802/ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5", size = 12744650 }, - { url = "https://files.pythonhosted.org/packages/be/38/266fbcbb3d0088862c9bafa8b1b99486691d2945a90b9a7316336a0d9a1b/ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4", size = 14458585 }, - { url = "https://files.pythonhosted.org/packages/63/a6/47fd0e96990ee9b7a4abda62de26d291bd3f7647218d05b7d6d38af47c30/ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6", size = 12419624 }, - { url = "https://files.pythonhosted.org/packages/84/5d/de0b7652e09f7dda49e1a3825a164a65f4998175b6486603c7601279baad/ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730", size = 11843238 }, - { url = "https://files.pythonhosted.org/packages/9e/be/3f341ceb1c62b565ec1fb6fd2139cc40b60ae6eff4b6fb8f94b1bb37c7a9/ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2", size = 11484012 }, - { url = "https://files.pythonhosted.org/packages/a3/c8/ff8acbd33addc7e797e702cf00bfde352ab469723720c5607b964491d5cf/ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519", size = 12038494 }, - { url = "https://files.pythonhosted.org/packages/73/b1/8d9a2c0efbbabe848b55f877bc10c5001a37ab10aca13c711431673414e5/ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b", size = 12473639 }, - { url = "https://files.pythonhosted.org/packages/cb/44/a673647105b1ba6da9824a928634fe23186ab19f9d526d7bdf278cd27bc3/ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c", size = 9834353 }, - { url = "https://files.pythonhosted.org/packages/c3/01/65cadb59bf8d4fbe33d1a750103e6883d9ef302f60c28b73b773092fbde5/ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4", size = 10821444 }, - { url = "https://files.pythonhosted.org/packages/69/cb/b3fe58a136a27d981911cba2f18e4b29f15010623b79f0f2510fd0d31fd3/ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b", size = 10038168 }, +version = "0.9.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/74/6c359f6b9ed85b88df6ef31febce18faeb852f6c9855651dfb1184a46845/ruff-0.9.5.tar.gz", hash = "sha256:11aecd7a633932875ab3cb05a484c99970b9d52606ce9ea912b690b02653d56c", size = 3634177 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/4b/82b7c9ac874e72b82b19fd7eab57d122e2df44d2478d90825854f9232d02/ruff-0.9.5-py3-none-linux_armv6l.whl", hash = "sha256:d466d2abc05f39018d53f681fa1c0ffe9570e6d73cde1b65d23bb557c846f442", size = 11681264 }, + { url = "https://files.pythonhosted.org/packages/27/5c/f5ae0a9564e04108c132e1139d60491c0abc621397fe79a50b3dc0bd704b/ruff-0.9.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38840dbcef63948657fa7605ca363194d2fe8c26ce8f9ae12eee7f098c85ac8a", size = 11657554 }, + { url = "https://files.pythonhosted.org/packages/2a/83/c6926fa3ccb97cdb3c438bb56a490b395770c750bf59f9bc1fe57ae88264/ruff-0.9.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d56ba06da53536b575fbd2b56517f6f95774ff7be0f62c80b9e67430391eeb36", size = 11088959 }, + { url = "https://files.pythonhosted.org/packages/af/a7/42d1832b752fe969ffdbfcb1b4cb477cb271bed5835110fb0a16ef31ab81/ruff-0.9.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7cb2a01da08244c50b20ccfaeb5972e4228c3c3a1989d3ece2bc4b1f996001", size = 11902041 }, + { url = "https://files.pythonhosted.org/packages/53/cf/1fffa09fb518d646f560ccfba59f91b23c731e461d6a4dedd21a393a1ff1/ruff-0.9.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:96d5c76358419bc63a671caac70c18732d4fd0341646ecd01641ddda5c39ca0b", size = 11421069 }, + { url = "https://files.pythonhosted.org/packages/09/27/bb8f1b7304e2a9431f631ae7eadc35550fe0cf620a2a6a0fc4aa3d736f94/ruff-0.9.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deb8304636ed394211f3a6d46c0e7d9535b016f53adaa8340139859b2359a070", size = 12625095 }, + { url = "https://files.pythonhosted.org/packages/d7/ce/ab00bc9d3df35a5f1b64f5117458160a009f93ae5caf65894ebb63a1842d/ruff-0.9.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df455000bf59e62b3e8c7ba5ed88a4a2bc64896f900f311dc23ff2dc38156440", size = 13257797 }, + { url = "https://files.pythonhosted.org/packages/88/81/c639a082ae6d8392bc52256058ec60f493c6a4d06d5505bccface3767e61/ruff-0.9.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de92170dfa50c32a2b8206a647949590e752aca8100a0f6b8cefa02ae29dce80", size = 12763793 }, + { url = "https://files.pythonhosted.org/packages/b3/d0/0a3d8f56d1e49af466dc770eeec5c125977ba9479af92e484b5b0251ce9c/ruff-0.9.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d28532d73b1f3f627ba88e1456f50748b37f3a345d2be76e4c653bec6c3e393", size = 14386234 }, + { url = "https://files.pythonhosted.org/packages/04/70/e59c192a3ad476355e7f45fb3a87326f5219cc7c472e6b040c6c6595c8f0/ruff-0.9.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c746d7d1df64f31d90503ece5cc34d7007c06751a7a3bbeee10e5f2463d52d2", size = 12437505 }, + { url = "https://files.pythonhosted.org/packages/55/4e/3abba60a259d79c391713e7a6ccabf7e2c96e5e0a19100bc4204f1a43a51/ruff-0.9.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11417521d6f2d121fda376f0d2169fb529976c544d653d1d6044f4c5562516ee", size = 11884799 }, + { url = "https://files.pythonhosted.org/packages/a3/db/b0183a01a9f25b4efcae919c18fb41d32f985676c917008620ad692b9d5f/ruff-0.9.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5b9d71c3879eb32de700f2f6fac3d46566f644a91d3130119a6378f9312a38e1", size = 11527411 }, + { url = "https://files.pythonhosted.org/packages/0a/e4/3ebfcebca3dff1559a74c6becff76e0b64689cea02b7aab15b8b32ea245d/ruff-0.9.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2e36c61145e70febcb78483903c43444c6b9d40f6d2f800b5552fec6e4a7bb9a", size = 12078868 }, + { url = "https://files.pythonhosted.org/packages/ec/b2/5ab808833e06c0a1b0d046a51c06ec5687b73c78b116e8d77687dc0cd515/ruff-0.9.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2f71d09aeba026c922aa7aa19a08d7bd27c867aedb2f74285a2639644c1c12f5", size = 12524374 }, + { url = "https://files.pythonhosted.org/packages/e0/51/1432afcc3b7aa6586c480142caae5323d59750925c3559688f2a9867343f/ruff-0.9.5-py3-none-win32.whl", hash = "sha256:134f958d52aa6fdec3b294b8ebe2320a950d10c041473c4316d2e7d7c2544723", size = 9853682 }, + { url = "https://files.pythonhosted.org/packages/b7/ad/c7a900591bd152bb47fc4882a27654ea55c7973e6d5d6396298ad3fd6638/ruff-0.9.5-py3-none-win_amd64.whl", hash = "sha256:78cc6067f6d80b6745b67498fb84e87d32c6fc34992b52bffefbdae3442967d6", size = 10865744 }, + { url = "https://files.pythonhosted.org/packages/75/d9/fde7610abd53c0c76b6af72fc679cb377b27c617ba704e25da834e0a0608/ruff-0.9.5-py3-none-win_arm64.whl", hash = "sha256:18a29f1a005bddb229e580795627d297dfa99f16b30c7039e73278cf6b5f9fa9", size = 10064595 }, ] [[package]] @@ -2828,7 +2932,8 @@ name = "scipy" version = "1.15.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/76/c6/8eb0654ba0c7d0bb1bf67bf8fbace101a8e4f250f7722371105e8b6f68fc/scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6", size = 59407493 } wheels = [ @@ -2891,6 +2996,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -3147,7 +3261,7 @@ wheels = [ [[package]] name = "tach" -version = "0.23.0" +version = "0.24.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gitpython" }, @@ -3159,18 +3273,18 @@ dependencies = [ { name = "tomli" }, { name = "tomli-w" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d9/4b/de2e7ad0a22e63fbed979064381da1290391dd623a3fd80d0728ea72d545/tach-0.23.0.tar.gz", hash = "sha256:ae123491231ab0712417d579b9a3259014d713d72626805ff64552955e43e912", size = 482218 } +sdist = { url = "https://files.pythonhosted.org/packages/05/2c/1afb1a3c16125b9cfc5a1da79ba2329dec11e16b9c9eea7ac411074a49cb/tach-0.24.1.tar.gz", hash = "sha256:63f7f3b3e3458a97ded020b524f32fc72bc731ff880d0709301b2802ff759721", size = 490250 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/87/9aa4142dc31314500af0003f406851a212b589a7e680e78c39751fc26681/tach-0.23.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:aa30db4158e48694154d346def14d3a096672381fa09e3cf09eae190ff9066f0", size = 3240516 }, - { url = "https://files.pythonhosted.org/packages/b3/db/3d856d856a688b024470494785dc8d177e1728904e180aa9394e80d8787e/tach-0.23.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e54365a3101c08a35d51357007e37723cd86c8bf464b73a3b43401edd2053d8", size = 3095903 }, - { url = "https://files.pythonhosted.org/packages/19/c9/1302175f5b350891727356c03bfdbffb884323db3c30cc34b2c7e93c932b/tach-0.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0af6be9328ec907deac141165b43b7db58f055bc20ea46b65b82b10fed72cd3", size = 3373159 }, - { url = "https://files.pythonhosted.org/packages/af/3d/ad4a2f4e2142b789085886a3acbb2f8e1a99068014303c7aa1166350aa38/tach-0.23.0-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b8205440863f61389b29a9baf2e2cd171d87c6931f3d6baf69eda69092440df", size = 3325828 }, - { url = "https://files.pythonhosted.org/packages/ab/87/4114a20e97f9a8652865bdf541d7b3121a731d6539d7f6b7d6bb70a86f46/tach-0.23.0-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b783d0f121c579f761dad7bf6ceeddec8f901e3778ed29a2db57c1c17804577", size = 3627127 }, - { url = "https://files.pythonhosted.org/packages/b5/cd/88b4f103eea5d2a3b0696265131f43f07e5bf9b1b81ccc0471512121ceae/tach-0.23.0-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:625403b59430eee9b5c2c05dff9575c8623ea88bcf58728e55b843fdbf04031d", size = 3623389 }, - { url = "https://files.pythonhosted.org/packages/12/77/3be44b77ad3ab8a6f05c245e399ff1e9f48df6be5e706c34b0863eaa4bdc/tach-0.23.0-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c9671d4be806f9aa6a714a38ac26b455704ac01019555f2441445335e749fb5", size = 3884923 }, - { url = "https://files.pythonhosted.org/packages/d7/8b/d7f9c9a1cb6a0f6745a1c4cdb824bc1abbac2a4f9fa30e57de37b7a223b9/tach-0.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caede4e23800d412c83b96288c7f03845971f6ea10dcfff40a026d294db1996f", size = 3483408 }, - { url = "https://files.pythonhosted.org/packages/48/8e/930460944b5cddeff297de774981ce8ffd1e80c59ea5f0616ade89a6871b/tach-0.23.0-cp37-abi3-win32.whl", hash = "sha256:828a59f7e2effdac3802025177b1a83e53b27ee54b00ef6305a0e36cec448e55", size = 2725999 }, - { url = "https://files.pythonhosted.org/packages/ea/01/4e4c9b551fa9ffd0db74e14966c393928aefa59019b6d5bd8a9a645ee714/tach-0.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:5dc03ef01d1a2e9d39fa238c271e9a4f8d9db2459212425ceb05b8ed0547000f", size = 2930346 }, + { url = "https://files.pythonhosted.org/packages/9e/b3/2af242caa456cd48c83ed8a3872c8eabe9d616d556ea52c1b39835f661c3/tach-0.24.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b965048c4918bd8d24d54a8a7a232bf6b210c1dd0c97caed83ac2f8db271db45", size = 3403749 }, + { url = "https://files.pythonhosted.org/packages/d1/2d/a64f5a9b0674527cc6c95fba681d7d53652f0cc092ce3d768e11409c3378/tach-0.24.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:bd203c8a581c6cf1f3813d5eeacd612bdb0c2681939677b33cc7d555d9216ff0", size = 3252234 }, + { url = "https://files.pythonhosted.org/packages/41/36/627ef905e792a0a281ce416581eae33e963b7dda5023460fd81ea0ab944e/tach-0.24.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73230ce1af9be01b08e42bd6002344562a5e51942b806869e0c3d784a38ae117", size = 3537522 }, + { url = "https://files.pythonhosted.org/packages/cc/90/d79c0cbfcae6f91b9c3cf5f2c077786057fcd59a4ca06608a3df1c072b3b/tach-0.24.1-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb982d6ead606ead1ca2d5decf1aa10414d6eecdded92de9755940acb18fd1df", size = 3497754 }, + { url = "https://files.pythonhosted.org/packages/77/5b/07fb1554509539cd4a2582a24b49ff3961cdb39cfe064429c8fd7b4fc9cb/tach-0.24.1-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50a56b14fcb8d311d07ac49fdec1a6619b4644b991112c17e894838827f198bb", size = 3814772 }, + { url = "https://files.pythonhosted.org/packages/36/33/1c9b051aada11d4171ba4a64cb537f1f95bc6d093cfae4d235bb0124813a/tach-0.24.1-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3756ea8fdd7ffeaaa4c2bb272ff3c407f51e7c83d8108ecc28f4acdcb11f5bd4", size = 3789273 }, + { url = "https://files.pythonhosted.org/packages/96/d8/6b3f624d5fa7db9a43e29887b643ae4c560127764e94aea93a4ec51a87e4/tach-0.24.1-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8f278a930651e7cafb5b2b8fd398cfc0ac205f9c81e618aad1d5bedcce86217d", size = 4057183 }, + { url = "https://files.pythonhosted.org/packages/b3/63/bd8028d67f36f4a35acbed746eb822be8825c1cc02eb990c780ad24877ee/tach-0.24.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6eb884a8936d9910d2d8675ad04726ecfba7ac830e09c2463acd561250f507e", size = 3655117 }, + { url = "https://files.pythonhosted.org/packages/6a/be/4a8ff273365dbafe2414665d81bb7416e0ed76b836ebfa6e5aa92ab579f9/tach-0.24.1-cp37-abi3-win32.whl", hash = "sha256:7d5db6480ea33ee95f023d9882b1d67863fb06eb802e97948d5b6c7b0a56bb39", size = 2857513 }, + { url = "https://files.pythonhosted.org/packages/8e/1a/92e7b283147e27750d1485fbe6bd595c64d9d8d017104971175bd82d4072/tach-0.24.1-cp37-abi3-win_amd64.whl", hash = "sha256:4e321f45a1457da49e9aab2f11630907776b0031e78242a80650b27413cb925c", size = 3071088 }, ] [[package]] @@ -3266,11 +3380,11 @@ wheels = [ [[package]] name = "types-pytz" -version = "2024.2.0.20241221" +version = "2025.1.0.20250204" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/26/516311b02b5a215e721155fb65db8a965d061372e388d6125ebce8d674b0/types_pytz-2024.2.0.20241221.tar.gz", hash = "sha256:06d7cde9613e9f7504766a0554a270c369434b50e00975b3a4a0f6eed0f2c1a9", size = 10213 } +sdist = { url = "https://files.pythonhosted.org/packages/b3/d2/2190c54d53c04491ad72a1df019c5dfa692e6ab6c2dba1be7b6c9d530e30/types_pytz-2025.1.0.20250204.tar.gz", hash = "sha256:00f750132769f1c65a4f7240bc84f13985b4da774bd17dfbe5d9cd442746bd49", size = 10352 } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/db/c92ca6920cccd9c2998b013601542e2ac5e59bc805bcff94c94ad254b7df/types_pytz-2024.2.0.20241221-py3-none-any.whl", hash = "sha256:8fc03195329c43637ed4f593663df721fef919b60a969066e22606edf0b53ad5", size = 10008 }, + { url = "https://files.pythonhosted.org/packages/be/50/65ffad73746f1d8b15992c030e0fd22965fd5ae2c0206dc28873343b3230/types_pytz-2025.1.0.20250204-py3-none-any.whl", hash = "sha256:32ca4a35430e8b94f6603b35beb7f56c32260ddddd4f4bb305fdf8f92358b87e", size = 10059 }, ] [[package]] From 64b90dc73e12d415b93d0ee6344df68f5c912c46 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 11 Feb 2025 14:14:14 +0100 Subject: [PATCH 140/178] fix[next][dace]: fix map fusion and loop blocking (#1856) Improve optimization for icon4py stencil `apply_diffusion_to_vn` by means of two changes: - Ignore check of dynamic volume property on memlet that prevented serial map fusion. - Add support for NestedSDFG nodes in loop blocking transformation. --- .../dace/transformations/loop_blocking.py | 28 ++-- .../dace/transformations/map_fusion_serial.py | 121 ++++++++--------- .../test_loop_blocking.py | 125 ++++++++++++++---- 3 files changed, 175 insertions(+), 99 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 344b0b8c22..fa77a7fd1d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -380,6 +380,24 @@ def _classify_node( ): return False + # Test if the body of the Tasklet depends on the block variable. + if self.blocking_parameter in node_to_classify.free_symbols: + return False + + elif isinstance(node_to_classify, dace.nodes.NestedSDFG): + # Same check as for Tasklets applies to the outputs of a nested SDFG node + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + + # Additionally, test if the symbol mapping depends on the block variable. + for v in node_to_classify.symbol_mapping.values(): + if self.blocking_parameter in v.free_symbols: + return False + elif isinstance(node_to_classify, dace_nodes.AccessNode): # AccessNodes need to have some special properties. node_desc: dace.data.Data = node_to_classify.desc(sdfg) @@ -422,16 +440,6 @@ def _classify_node( if out_edge.dst is outer_exit: return False - # Now we have ensured that the partition exists, thus we will now evaluate - # if the node is independent or dependent. - - # Test if the body of the Tasklet depends on the block variable. - if ( - isinstance(node_to_classify, dace_nodes.Tasklet) - and self.blocking_parameter in node_to_classify.free_symbols - ): - return False - # Now we have to look at incoming edges individually. # We will inspect the subset of the Memlet to see if they depend on the # block variable. If this loop ends normally, then we classify the node diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 27d962d0bd..977f2933b5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -125,8 +125,8 @@ def can_be_applied( output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, + first_map_exit=map_exit_1, + second_map_entry=map_entry_2, ) if output_partition is None: return False @@ -375,8 +375,8 @@ def partition_first_outputs( self, state: SDFGState, sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, ) -> Union[ Tuple[ Set[graph.MultiConnectorEdge[dace.Memlet]], @@ -385,19 +385,19 @@ def partition_first_outputs( ], None, ]: - """Partition the output edges of `map_exit_1` for serial map fusion. + """Partition the output edges of `first_map_exit` for serial map fusion. The output edges of the first map are partitioned into three distinct sets, defined as follows: - - Pure Output Set `\mathbb{P}`: + * Pure Output Set `\mathbb{P}`: These edges exits the first map and does not enter the second map. These outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: + * Exclusive Intermediate Set `\mathbb{E}`: Edges in this set leaves the first map exit, enters an access node, from where a Memlet then leads immediately to the second map. The memory referenced by this access node is not used anywhere else, thus it can be removed. - - Shared Intermediate Set `\mathbb{S}`: + * Shared Intermediate Set `\mathbb{S}`: These edges are very similar to the one in `\mathbb{E}` except that they are used somewhere else, thus they can not be removed and have to be recreated as output of the second map. @@ -406,17 +406,14 @@ def partition_first_outputs( output can be added to either intermediate set and might fail to compute the partition, even if it would exist. - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. + :return: If such a decomposition exists the function will return the three sets + mentioned above in the same order. In case the decomposition does not exist, + i.e. the maps can not be fused the function returns `None`. - Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. + :param state: The in which the two maps are located. + :param sdfg: The full SDFG in whcih we operate. + :param first_map_exit: The exit node of the first map. + :param second_map_entry: The entry node of the second map. """ # The three outputs set. pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() @@ -425,28 +422,17 @@ def partition_first_outputs( # Compute the renaming that for translating the parameter of the _second_ # map to the ones used by the first map. - repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] - first_map=map_exit_1.map, - second_map=map_entry_2.map, + param_repl: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=first_map_exit.map, + second_map=second_map_entry.map, ) - assert repl_dict is not None + assert param_repl is not None # Set of intermediate nodes that we have already processed. processed_inter_nodes: Set[nodes.Node] = set() - # These are the data that is written to multiple times in _this_ state. - # If a data is written to multiple time in a state, it could be - # classified as shared. However, it might happen that the node has zero - # degree. This is not a problem as the maps also induced a before-after - # relationship. But some DaCe transformations do not catch this. - # Thus we will never modify such intermediate nodes and fail instead. - if self.strict_dataflow: - multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) - else: - multi_write_data = set() - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): + for out_edge in state.out_edges(first_map_exit): intermediate_node: nodes.Node = out_edge.dst # We already processed the node, this should indicate that we should @@ -469,7 +455,7 @@ def partition_first_outputs( if not self.is_node_reachable_from( graph=state, begin=intermediate_node, - end=map_entry_2, + end=second_map_entry, ): pure_outputs.add(out_edge) continue @@ -479,6 +465,12 @@ def partition_first_outputs( # cases, as handling them is essentially rerouting an edge, whereas # handling intermediate nodes is much more complicated. + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + # For us an intermediate node must always be an access node, because # everything else we do not know how to handle. It is important that # we do not test for non transient data here, because they can be @@ -488,22 +480,6 @@ def partition_first_outputs( if self.is_view(intermediate_node, sdfg): return None - # Checks if the intermediate node refers to data that is accessed by - # _other_ access nodes in _this_ state. If this is the case then never - # touch this intermediate node. - # TODO(phimuell): Technically it would be enough to turn the node into - # a shared output node, because this will still fulfil the dependencies. - # However, some DaCe transformation can not handle this properly, so we - # are _forced_ to reject this node. - if intermediate_node.data in multi_write_data: - return None - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None - # It can happen that multiple edges converges at the `IN_` connector # of the first map exit, but there is only one edge leaving the exit. # It is complicate to handle this, so for now we ignore it. @@ -511,7 +487,7 @@ def partition_first_outputs( # To handle this we need to associate a consumer edge (the outgoing edges # of the second map) with exactly one producer. producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) ) if len(producer_edges) > 1: return None @@ -520,7 +496,7 @@ def partition_first_outputs( # - The source of the producer can not be a view (we do not handle this) # - The edge shall also not be a reduction edge. # - Defined location to where they write. - # - No dynamic Memlets. + # - No dynamic Melets. # Furthermore, we will also extract the subsets, i.e. the location they # modify inside the intermediate array. # Since we do not allow for WCR, we do not check if the producer subsets intersects. @@ -531,6 +507,7 @@ def partition_first_outputs( ): return None if producer_edge.data.dynamic: + # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely. return None if producer_edge.data.wcr is not None: return None @@ -562,9 +539,9 @@ def partition_first_outputs( for intermediate_node_out_edge in state.out_edges(intermediate_node): # If the second map entry is not immediately reachable from the intermediate # node, then ensure that there is not path that goes to it. - if intermediate_node_out_edge.dst is not map_entry_2: + if intermediate_node_out_edge.dst is not second_map_entry: if self.is_node_reachable_from( - graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2 + graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry ): return None continue @@ -583,15 +560,16 @@ def partition_first_outputs( # Now we look at all edges that leave the second map entry, i.e. the # edges that feeds the consumer and define what is read inside the map. # We do not check them, but collect them and inspect them. - # NOTE: The subset still uses the old iteration variables. + # NOTE1: The subset still uses the old iteration variables. + # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets. + # This is different compared to the producer Memlet. The reason is + # because in a consumer the data is conditionally read, so the data + # has to exists anyway. for inner_consumer_edge in state.out_edges_by_connector( - map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:] + second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:] ): if inner_consumer_edge.data.src_subset is None: return None - if inner_consumer_edge.data.dynamic: - # TODO(phimuell): Is this restriction necessary, I am not sure. - return None consumer_subsets.append(inner_consumer_edge.data.src_subset) assert ( found_second_map @@ -599,11 +577,11 @@ def partition_first_outputs( assert len(consumer_subsets) != 0 # The consumer still uses the original symbols of the second map, so we must rename them. - if repl_dict: + if param_repl: consumer_subsets = copy.deepcopy(consumer_subsets) for consumer_subset in consumer_subsets: symbolic.safe_replace( - mapping=repl_dict, replace_callback=consumer_subset.replace + mapping=param_repl, replace_callback=consumer_subset.replace ) # Now we are checking if a single iteration of the first (top) map @@ -623,6 +601,21 @@ def partition_first_outputs( # output of the second map. if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): # The intermediate data is used somewhere else, either in this or another state. + # NOTE: If the intermediate is shared, then we will turn it into a + # sink node attached to the combined map exit. Technically this + # should be enough, even if the same data appears again in the + # dataflow down streams. However, some DaCe transformations, + # I am looking at you `auto_optimizer()` do not like that. Thus + # if the intermediate is used further down in the same datadflow, + # then we consider that the maps can not be fused. But we only + # do this in the strict data flow mode. + if self.strict_dataflow: + if self._is_data_accessed_downstream( + data=intermediate_node.data, + graph=state, + begin=intermediate_node, # is ignored itself. + ): + return None shared_outputs.add(out_edge) else: # The intermediate can be removed, as it is not used anywhere else. @@ -669,8 +662,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, + first_map_exit=map_exit_1, + second_map_entry=map_entry_2, ) assert output_partition is not None # Make MyPy happy. pure_outputs, exclusive_outputs, shared_outputs = output_partition diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index 4d7a8156d7..86136994dc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from enum import Enum from typing import Callable import numpy as np @@ -506,8 +507,14 @@ def test_empty_memlet_3(): assert scope_dict[task2] is inner_mentry +class IndependentPart(Enum): + NONE = 0 + TASKLET = 1 + NESTED_SDFG = 2 + + def _make_loop_blocking_sdfg_with_inner_map( - add_independent_part: bool, + add_independent_part: IndependentPart, ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: """ Generate the SDFGs with an inner map. @@ -533,15 +540,6 @@ def _make_loop_blocking_sdfg_with_inner_map( "computation", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = __in1 + __in2" ) - if add_independent_part: - sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) - sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) - sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) - tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) - tskli = state.add_tasklet( - "independent_comp", inputs={"__field"}, outputs={"__out"}, code="__out = __field[1, 1]" - ) - # construct the inner map of the map. state.add_edge(A, None, me_out, "IN_A", dace.Memlet("A[0:10, 0:10]")) me_out.add_in_connector("IN_A") @@ -565,14 +563,42 @@ def _make_loop_blocking_sdfg_with_inner_map( mx_out.add_out_connector("OUT_B") # If requested add a part that is independent, i.e. is before the inner loop - if add_independent_part: - state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) - state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + if add_independent_part != IndependentPart.NONE: + sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) state.add_edge(tmp, None, tmp2, None, dace.Memlet("tmp2[0]")) state.add_edge(tmp2, None, mx_out, "IN_tmp", dace.Memlet("C[__i0]")) mx_out.add_in_connector("IN_tmp") state.add_edge(mx_out, "OUT_tmp", C, None, dace.Memlet("C[0:10]")) mx_out.add_out_connector("OUT_tmp") + match add_independent_part: + case IndependentPart.TASKLET: + tskli = state.add_tasklet( + "independent_comp", + inputs={"__field"}, + outputs={"__out"}, + code="__out = __field[1, 1]", + ) + state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) + state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + case IndependentPart.NESTED_SDFG: + nsdfg_sym, nsdfg_inp, nsdfg_out = ("S", "I", "V") + nsdfg = _make_conditional_block_sdfg( + "independent_comp", nsdfg_sym, nsdfg_inp, nsdfg_out + ) + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + inputs={nsdfg_inp}, + outputs={nsdfg_out}, + symbol_mapping={nsdfg_sym: 0}, + ) + state.add_edge(me_out, "OUT_A", nsdfg_node, nsdfg_inp, dace.Memlet("A[1, 1]")) + state.add_edge(nsdfg_node, nsdfg_out, tmp, None, dace.Memlet("tmp[0]")) + case _: + raise NotImplementedError() sdfg.validate() return sdfg, state, me_out, me_in @@ -582,7 +608,9 @@ def test_loop_blocking_inner_map(): """ Tests with an inner map, without an independent part. """ - sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(False) + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map( + IndependentPart.NONE + ) assert all(oedge.dst is inner_map for oedge in state.out_edges(outer_map)) count = sdfg.apply_transformations_repeated( @@ -605,20 +633,23 @@ def test_loop_blocking_inner_map(): assert all(oedge.dst is inner_map for oedge in state.out_edges(inner_blocking_map)) -def test_loop_blocking_inner_map_with_independent_part(): +@pytest.mark.parametrize("independent_part", [IndependentPart.TASKLET, IndependentPart.NESTED_SDFG]) +def test_loop_blocking_inner_map_with_independent_part(independent_part): """ Tests with an inner map with an independent part. """ - sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(True) + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(independent_part) # Find the parts that are independent. - itskl: dace_nodes.Tasklet = next( + independent_node: dace_nodes.Tasklet | dace_nodes.NestedSDFG = next( oedge.dst for oedge in state.out_edges(outer_map) - if isinstance(oedge.dst, dace_nodes.Tasklet) + if isinstance(oedge.dst, (dace_nodes.Tasklet, dace_nodes.NestedSDFG)) + ) + assert independent_node.label == "independent_comp" + i_access_node: dace_nodes.AccessNode = next( + oedge.dst for oedge in state.out_edges(independent_node) ) - assert itskl.label == "independent_comp" - i_access_node: dace_nodes.AccessNode = next(oedge.dst for oedge in state.out_edges(itskl)) assert i_access_node.data == "tmp" count = sdfg.apply_transformations_repeated( @@ -634,7 +665,9 @@ def test_loop_blocking_inner_map_with_independent_part(): ) assert inner_blocking_map is not inner_map - assert all(oedge.dst in {inner_blocking_map, itskl} for oedge in state.out_edges(outer_map)) + assert all( + oedge.dst in {inner_blocking_map, independent_node} for oedge in state.out_edges(outer_map) + ) assert state.scope_dict()[i_access_node] is outer_map assert all(oedge.dst is inner_blocking_map for oedge in state.out_edges(i_access_node)) @@ -745,7 +778,34 @@ def _apply_and_run_mixed_memlet_sdfg(sdfg: dace.SDFG) -> None: assert all(np.allclose(ref[name], res[name]) for name in ref) -def test_loop_blocking_mixked_memlets_1(): +def _make_conditional_block_sdfg(sdfg_label: str, sym: str, inp: str, out: str): + sdfg = dace.SDFG(sdfg_label) + for data in [inp, out]: + sdfg.add_scalar(data, dtype=dace.float64) + + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + entry_state = sdfg.add_state("entry", is_start_block=True) + sdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"{sym} % 2 == 0"), then_body) + tskli = tstate.add_tasklet("write_0", inputs={"inp"}, outputs={"val"}, code=f"val = inp + 0") + tstate.add_edge(tstate.add_access(inp), None, tskli, "inp", dace.Memlet(f"{inp}[0]")) + tstate.add_edge(tskli, "val", tstate.add_access(out), None, dace.Memlet(f"{out}[0]")) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"{sym} % 2 != 0"), else_body) + tskli = fstate.add_tasklet("write_1", inputs={"inp"}, outputs={"val"}, code=f"val = inp + 1") + fstate.add_edge(fstate.add_access(inp), None, tskli, "inp", dace.Memlet(f"{inp}[0]")) + fstate.add_edge(tskli, "val", fstate.add_access(out), None, dace.Memlet(f"{out}[0]")) + + return sdfg + + +def test_loop_blocking_mixed_memlets_1(): sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(True) mx = state.exit_node(me) @@ -782,7 +842,7 @@ def test_loop_blocking_mixked_memlets_1(): ) -def test_loop_blocking_mixked_memlets_2(): +def test_loop_blocking_mixed_memlets_2(): sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(False) mx = state.exit_node(me) @@ -810,7 +870,7 @@ def test_loop_blocking_no_independent_nodes(): sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) state = sdfg.add_state(is_start_block=True) - names = ["A", "B"] + names = ["A", "B", "C"] for aname in names: sdfg.add_array( aname, @@ -818,13 +878,28 @@ def test_loop_blocking_no_independent_nodes(): dtype=dace.float64, transient=False, ) - state.add_mapped_tasklet( + A = state.add_access("A") + _, me, mx = state.add_mapped_tasklet( "fully_dependent_computation", map_ranges={"__i0": "0:10", "__i1": "0:10"}, inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, code="__out = __in1 + 10.0", outputs={"__out": dace.Memlet("B[__i0, __i1]")}, external_edges=True, + input_nodes={A}, + ) + nsdfg_sym, nsdfg_inp, nsdfg_out = ("S", "I", "V") + nsdfg = _make_conditional_block_sdfg("dependent_component", nsdfg_sym, nsdfg_inp, nsdfg_out) + nsdfg_node = state.add_nested_sdfg( + nsdfg, sdfg, inputs={nsdfg_inp}, outputs={nsdfg_out}, symbol_mapping={nsdfg_sym: "__i1"} + ) + state.add_memlet_path(A, me, nsdfg_node, dst_conn=nsdfg_inp, memlet=dace.Memlet("A[1,1]")) + state.add_memlet_path( + nsdfg_node, + mx, + state.add_access("C"), + src_conn=nsdfg_out, + memlet=dace.Memlet("C[__i0, __i1]"), ) sdfg.validate() From 83985b29389c9889e4969bcb95035508265eeb3e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 12 Feb 2025 12:35:00 +0100 Subject: [PATCH 141/178] build: bump minimum matplotlib version (#1858) Daily CI builds the minimum version which doesn't have wheels for python 3.10 and 3.11. Didn't investigate why it worked at some point... --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b512c6c93e..1efce6bd29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dev = [ docs = [ 'esbonio>=0.16.0', 'jupytext>=1.14', - 'matplotlib>=3.3', + 'matplotlib>=3.8.4', 'myst-parser>=4.0.0', 'pygments>=2.7.3', 'sphinx>=7.3.7', diff --git a/uv.lock b/uv.lock index c07a329b39..dbcb32411d 100644 --- a/uv.lock +++ b/uv.lock @@ -1274,7 +1274,7 @@ dev = [ { name = "esbonio", specifier = ">=0.16.0" }, { name = "hypothesis", specifier = ">=6.0.0" }, { name = "jupytext", specifier = ">=1.14" }, - { name = "matplotlib", specifier = ">=3.3" }, + { name = "matplotlib", specifier = ">=3.8.4" }, { name = "mypy", extras = ["faster-cache"], specifier = ">=1.13.0" }, { name = "myst-parser", specifier = ">=4.0.0" }, { name = "nbmake", specifier = ">=1.4.6" }, @@ -1305,7 +1305,7 @@ dev = [ docs = [ { name = "esbonio", specifier = ">=0.16.0" }, { name = "jupytext", specifier = ">=1.14" }, - { name = "matplotlib", specifier = ">=3.3" }, + { name = "matplotlib", specifier = ">=3.8.4" }, { name = "myst-parser", specifier = ">=4.0.0" }, { name = "pygments", specifier = ">=2.7.3" }, { name = "sphinx", specifier = ">=7.3.7" }, From 27ad3840f7ffeaf916a8648f1c3469afc2b7c73a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 14 Feb 2025 07:44:05 +0100 Subject: [PATCH 142/178] fix[dace][next]: Fixed `LoopBlocking` (#1859) Co-authored-by: Edoardo Paone Co-authored-by: edopao --- .../dace/transformations/loop_blocking.py | 6 +- .../test_loop_blocking.py | 65 +++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index fa77a7fd1d..826b5949f2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -204,9 +204,9 @@ def _prepare_inner_outer_maps( inner_label = f"inner_{outer_map.label}" inner_range = { self.blocking_parameter: dace_subsets.Range.from_string( - f"({coarse_block_var} * {self.blocking_size} + {rng_start})" + f"(({rng_start}) + ({coarse_block_var}) * ({self.blocking_size}))" + ":" - + f"min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" + + f"min(({rng_start}) + ({coarse_block_var} + 1) * ({self.blocking_size}), ({rng_stop}) + 1)" ) } inner_entry, inner_exit = state.add_map( @@ -219,7 +219,7 @@ def _prepare_inner_outer_maps( # Now we modify the properties of the outer map. coarse_block_range = dace_subsets.Range.from_string( - f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" + f"0:int_ceil((({rng_stop}) + 1) - ({rng_start}), ({self.blocking_size}))" ).ranges[0] outer_map.params[blocking_parameter_dim] = coarse_block_var outer_map.range[blocking_parameter_dim] = coarse_block_range diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index 86136994dc..3b41da6336 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -927,3 +927,68 @@ def test_loop_blocking_no_independent_nodes(): validate_all=True, ) assert count == 1 + + +def _make_only_last_two_elements_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("B", dace.int32) + sdfg.add_symbol("M", dace.int32) + + for name in "acb": + sdfg.add_array( + name, + shape=(20, 10), + dtype=dace.float64, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"i": "B:N", "k": "(M-2):M"}, + inputs={ + "__in1": dace.Memlet("a[i, k]"), + "__in2": dace.Memlet("b[i, k]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("c[i, k]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_only_last_two_elements_sdfg(): + sdfg = _make_only_last_two_elements_sdfg() + + def ref_comp(a, b, c, B, N, M): + for i in range(B, N): + for k in range(M - 2, M): + c[i, k] = a[i, k] + b[i, k] + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=1, + blocking_parameter="k", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 + + ref = { + "a": np.array(np.random.rand(20, 10), dtype=np.float64), + "b": np.array(np.random.rand(20, 10), dtype=np.float64), + "c": np.zeros((20, 10), dtype=np.float64), + "B": 0, + "N": 20, + "M": 6, + } + res = copy.deepcopy(ref) + + ref_comp(**ref) + sdfg(**res) + + assert np.allclose(ref["c"], res["c"]) From 937e894aebb517a6bd46742bf65cdc2b7118cd43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 14 Feb 2025 08:35:12 +0100 Subject: [PATCH 143/178] fix[dace][next]: Update MapFusion (#1857) Last year, the then state of MapFusion from [DaCe PR#1629](https://github.com/spcl/dace/pull/1629) was added to GT4Py, as a temporary fix until the PR in DaCe is merged and parallel map fusion has become available there. However, during that time the transformation in the PR has evolved and improved and some of the bug that were fixed are now appearing in GT4Py, for example [PR#1850](https://github.com/GridTools/gt4py/pull/1850) and [PR#1856](https://github.com/GridTools/gt4py/pull/1856). Thus this PR updated the MapFusion transformation that is currently inside GT4Py and replaces it with newest development version from DaCe. Because we need it, and it was designed from the start to be that way, it also adds parallel map fusion to the transformation. As before, this transformation, currently fully located in `map_fusion_dace.py`, is only kept inside the repo until DaCe has caught up to it. The PR also introduces some additional memory layer that encapsulates the DaCe transformation. Something that we have to deal with in the long run and we currently do because other parts of the toolchain require it. --------- Co-authored-by: edopao --- .../runners/dace/transformations/__init__.py | 4 +- .../dace/transformations/auto_optimize.py | 29 +- .../runners/dace/transformations/gpu_utils.py | 8 +- .../dace/transformations/map_fusion.py | 180 ++ .../dace/transformations/map_fusion_dace.py | 2090 +++++++++++++++++ .../dace/transformations/map_fusion_helper.py | 676 ------ .../transformations/map_fusion_parallel.py | 170 -- .../dace/transformations/map_fusion_serial.py | 1053 --------- .../dace/transformations/map_promoter.py | 6 +- .../transformation_tests/test_map_fusion.py | 102 +- 10 files changed, 2389 insertions(+), 1929 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py delete mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py delete mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py delete mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index c8e1cf292f..df48d35d39 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -21,8 +21,7 @@ ) from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking -from .map_fusion_parallel import MapFusionParallel -from .map_fusion_serial import MapFusionSerial +from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial from .map_orderer import MapIterationOrder, gt_set_iteration_order from .map_promoter import SerialMapPromoter from .simplify import ( @@ -52,6 +51,7 @@ "GT4PyMapBufferElimination", "GT4PyMoveTaskletIntoMap", "LoopBlocking", + "MapFusion", "MapFusionParallel", "MapFusionSerial", "MapIterationOrder", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 8137c60959..d6e9fc259d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -13,6 +13,7 @@ import dace from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize +from dace.transformation.passes import analysis as dace_analysis from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @@ -328,22 +329,26 @@ def gt_auto_fuse_top_level_maps( # after the other, thus new opportunities might arise in the next round. # We use the hash of the SDFG to detect if we have reached a fix point. for _ in range(max_optimization_rounds): - # Use map fusion to reduce their number and to create big kernels # TODO(phimuell): Use a cost measurement to decide if fusion should be done. # TODO(phimuell): Add parallel fusion transformation. Should it run after # or with the serial one? + # TODO(phimuell): Switch to `FullMapFusion` once DaCe has parallel map fusion + # and [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. + + # First we do scan the entire SDFG to figure out which data is only + # used once and can be deleted. MapFusion could do this on its own but + # it is more efficient to do it once and then reuse it. + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + fusion_transformation = gtx_transformations.MapFusion( + only_toplevel_maps=True, + only_if_common_ancestor=False, + ) + fusion_transformation._single_use_data = single_use_data + sdfg.apply_transformations_repeated( - [ - gtx_transformations.MapFusionSerial( - only_toplevel_maps=True, - ), - gtx_transformations.MapFusionParallel( - only_toplevel_maps=True, - # This will lead to the creation of big probably unrelated maps. - # However, it might be good. - only_if_common_ancestor=False, - ), - ], + fusion_transformation, validate=validate, validate_all=validate_all, ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index c2ac528647..6359cc1127 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -171,7 +171,7 @@ def gt_gpu_transform_non_standard_memlet( # This function allows to restrict any fusion operation to the maps # that we have just created. def restrict_fusion_to_newly_created_maps( - self: gtx_transformations.map_fusion_helper.MapFusionHelper, + self: gtx_transformations.MapFusion, map_entry_1: dace_nodes.MapEntry, map_entry_2: dace_nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], @@ -690,9 +690,9 @@ def can_be_applied( self._promote_map(graph, replace_trivail_map_parameter=False) if not gtx_transformations.MapFusionSerial.can_be_applied_to( sdfg=sdfg, - map_exit_1=trivial_map_exit, - intermediate_access_node=self.access_node, - map_entry_2=self.second_map_entry, + first_map_exit=trivial_map_exit, + array=self.access_node, + second_map_entry=self.second_map_entry, ): return False finally: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py new file mode 100644 index 0000000000..00828520c8 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py @@ -0,0 +1,180 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +"""An interface between DaCe's MapFusion and the one of GT4Py.""" + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Any, Callable, Optional, TypeAlias, TypeVar, Union + +import dace +from dace import nodes as dace_nodes, properties as dace_properties + +from gt4py.next.program_processors.runners.dace.transformations import ( + map_fusion_dace as dace_map_fusion, +) + + +_MapFusionType = TypeVar("_MapFusionType", bound="dace_map_fusion.MapFusion") + +FusionTestCallback: TypeAlias = Callable[ + [_MapFusionType, dace_nodes.MapEntry, dace_nodes.MapEntry, dace.SDFGState, dace.SDFG, int], bool +] +"""Callback for the map fusion transformation to check if a fusion should be performed. + +The callback returns `True` if the fusion should be performed and `False` if it +should be rejected. See also the description of GT4Py's MapFusion transformation for +more information. + +The arguments are as follows: +- The transformation object that is active. +- The MapEntry node of the first map; exact meaning depends on if parallel or + serial map fusion is performed. +- The MapEntry node of the second map; exact meaning depends on if parallel or + serial map fusion is performed. +- The SDFGState that that contains the data flow. +- The SDFG that is processed. +- The expression index, see `expr_index` in `can_be_applied()` it is `0` for + serial map fusion and `1` for parallel map fusion. +""" + + +@dace_properties.make_properties +class MapFusion(dace_map_fusion.MapFusion): + """GT4Py's MapFusion transformation. + + It is a wrapper that adds some functionality to the transformation that is not + present in the DaCe version of this transformation. + There are two important differences when compared with DaCe's MapFusion: + - In DaCe strict data flow is enabled by default, in GT4Py it is disabled by default. + - GT4Py accepts an additional argument `apply_fusion_callback`. This is a + function that is called by the transformation, at the _beginning_ of + `self.can_be_applied()`, i.e. before the transformation does any check if + the maps can be fused. If this function returns `False`, `self.can_be_applied()` + ends and returns `False`. In case the callback returns `True` the transformation + will perform the usual steps to check if the transformation can apply or not. + For the signature see `FusionTestCallback`. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: Strict dataflow mode should be used, it is disabled by default. + assume_always_shared: Assume that all intermediates are shared. + allow_serial_map_fusion: Allow serial map fusion, by default `True`. + allow_parallel_fusion: Allow to merge parallel maps, by default `False`. + only_if_common_ancestor: In parallel map fusion mode, only fuse if both maps + have a common direct ancestor. + apply_fusion_callback: The callback function that is used. + + Todo: + Investigate ways of how to remove this intermediate layer. The main reason + why we need it is the callback functionality, but it is not needed often + and in these cases it might be solved differently. + """ + + _apply_fusion_callback: Optional[FusionTestCallback] + + def __init__( + self, + strict_dataflow: bool = False, + apply_fusion_callback: Optional[FusionTestCallback] = None, + **kwargs: Any, + ) -> None: + self._apply_fusion_callback = None + super().__init__(strict_dataflow=strict_dataflow, **kwargs) + if apply_fusion_callback is not None: + self._apply_fusion_callback = apply_fusion_callback + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs basic checks if the maps can be fused. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. + """ + assert expr_index in [0, 1] + + # If the call back is given then proceed with it. + if self._apply_fusion_callback is not None: + if expr_index == 0: # Serial MapFusion. + first_map_entry: dace_nodes.MapEntry = graph.entry_node(self.first_map_exit) + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + elif expr_index == 1: # Parallel MapFusion + first_map_entry = self.first_parallel_map_entry + second_map_entry = self.second_parallel_map_entry + else: + raise NotImplementedError(f"Not implemented expression: {expr_index}") + + # Apply the call back. + if not self._apply_fusion_callback( + self, + first_map_entry, + second_map_entry, + graph, + sdfg, + expr_index, + ): + return False + + # Now forward to the underlying implementation. + return super().can_be_applied( + graph=graph, + expr_index=expr_index, + sdfg=sdfg, + permissive=permissive, + ) + + +@dace_properties.make_properties +class MapFusionSerial(MapFusion): + """Wrapper around `MapFusion` that only supports serial map fusion. + + Note: + This class exists only for the transition period. + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + assert "allow_serial_map_fusion" not in kwargs + assert "allow_parallel_map_fusion" not in kwargs + super().__init__( + allow_serial_map_fusion=True, + allow_parallel_map_fusion=False, + **kwargs, + ) + + +@dace_properties.make_properties +class MapFusionParallel(MapFusion): + """Wrapper around `MapFusion` that only supports parallel map fusion. + + Note: + This class exists only for the transition period. + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + assert "allow_serial_map_fusion" not in kwargs + assert "allow_parallel_map_fusion" not in kwargs + super().__init__( + allow_serial_map_fusion=False, + allow_parallel_map_fusion=True, + **kwargs, + ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py new file mode 100644 index 0000000000..c301ce0ac4 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py @@ -0,0 +1,2090 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements Helper functionaliyies for map fusion + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +import copy +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import dace +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers + + +@properties.make_properties +class MapFusion(transformation.SingleStateTransformation): + """Implements the MapFusion transformation. + + From a high level perspective it will remove the MapExit node of the first and the MapEntry node of + the second Map. It will then rewire and modify the Memlets such that the data flow bypasses the + intermediate node. For this a new intermediate node will be created, which is much smaller because + it has no longer to store the whole output of the first map, but only the data that is produced by + a single iteration of the first map. The transformation will then remove the old intermediate. + Thus by merging the two Maps together the transformation will reduce the memory footprint. It is + important that it is not always possible to fully remove the intermediate node. For example the + data might be used somewhere else. In this case the intermediate will become an output of the Map. + + An example would be the following: + ```python + for i in range(N): + T[i] = foo(A[i]) + for j in range(N): + B[j] = bar(T[i]) + ``` + which would be translated into: + ```python + for i in range(N): + temp: scalar = foo(A[i]) + B[i] = bar(temp) + ``` + + The checks that two Maps can be fused are quite involved, however, they essentially check: + * If the two Maps cover the same iteration space, essentially have the same start, stop and + iteration , see `find_parameter_remapping()`. + * Furthermore, they verify if the new fused Map did not introduce read write conflict, + essentially it tests if the data is pointwise, i.e. what is read is also written, + see `has_read_write_dependency()`. + * Then it will examine the intermediate data. This will essentially test if the data that + is needed by a single iteration of the second Map is produced by a single iteration of + the first Map, see `partition_first_outputs()`. + + By default `strict_dataflow` is enabled. In this mode the transformation is more conservative. + The main difference is, that it will not adjust the subsets of the intermediate, i.e. turning + an array with shape `(1, 1, 1, 1)` into a scalar. Furthermore, shared intermediates, see + `partition_first_outputs()` will only be created if the data is not referred downstream in + the dataflow. + + In order to determine if an intermediate can be removed or has to be kept, it is in general + necessary to scan the whole SDFG, which is the default behaviour. There are two ways to + speed this up. The first way is to set `assume_always_shared` to `True`. In this case the + transformation will not perform the scan, but assume that the data is shared, i.e. used + somewhere else. This might lead to dead data flow. + The second way is to use the transformation inside a pipeline, which includes the + `FindSingleUseData` analysis pass. If the result of this pass is present then the + transformation will use it instead to determine if a intermediate can be removed. + Note that `assume_always_shared` takes precedence. + For this pattern the `FullMapFusion` pass is provided, that combines the analysis + pass and `MapFusion`. + + By default this transformation only handles the case where to maps are right after each other, + separated by an intermediate array. However, by setting `allow_parallel_map_fusion` to `True`, + the transformation will be _in addition_ also be able to handle the case where the Maps are + parallel (parallel here means that neither of the two Map can be reached from the other; see + `is_parallel()`). If you only want to perform parallel map fusion you also have to set + `allow_serial_map_fusion` to `False`. + + :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + :param only_toplevel_maps: Only consider Maps that are at the top. + :param strict_dataflow: Which dataflow mode should be used, see above. + :param assume_always_shared: Assume that all intermediates are shared. + :param allow_serial_map_fusion: Allow serial map fusion, by default `True`. + :param allow_parallel_map_fusion: Allow to merge parallel maps, by default `False`. + :param only_if_common_ancestor: In parallel map fusion mode, only fuse if both map + have a common direct ancestor. + + :note: This transformation modifies more nodes than it matches. + :note: If `assume_always_shared` is `True` then the transformation will assume that + all intermediates are shared. This avoids the problems mentioned above with + the cache at the expense of the creation of dead dataflow. + """ + + # Pattern Nodes: For the serial map fusion + # NOTE: Can only be accessed in the `can_serial_map_fusion_be_applied()` and the + # `apply_serial_map_fusion()` functions. + first_map_exit = transformation.transformation.PatternNode(nodes.MapExit) + array = transformation.transformation.PatternNode(nodes.AccessNode) + second_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + # Pattern Nodes: For the parallel map fusion + # NOTE: Can only be used in the `can_parallel_map_fusion_be_applied()` and the + # `apply_map_fusion_parallel()` functions. + first_parallel_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + second_parallel_map_entry = transformation.transformation.PatternNode(nodes.MapEntry) + + # Settings + only_toplevel_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = properties.Property( + dtype=bool, + default=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + strict_dataflow = properties.Property( + dtype=bool, + default=True, + desc="If `True` then the transformation will ensure a more stricter data flow.", + ) + assume_always_shared = properties.Property( + dtype=bool, + default=False, + desc="If `True` then all intermediates will be classified as shared.", + ) + + allow_serial_map_fusion = properties.Property( + dtype=bool, + default=True, + desc="If `True`, the default, then allow serial map fusion.", + ) + + allow_parallel_map_fusion = properties.Property( + dtype=bool, + default=False, + desc="If `True` then also perform parallel map fusion, disabled by default.", + ) + only_if_common_ancestor = properties.Property( + dtype=bool, + default=False, + desc="If `True` restrict parallel map fusion to maps that have a direct common ancestor.", + ) + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + assume_always_shared: Optional[bool] = None, + allow_serial_map_fusion: Optional[bool] = None, + allow_parallel_map_fusion: Optional[bool] = None, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = only_toplevel_maps + if only_inner_maps is not None: + self.only_inner_maps = only_inner_maps + if strict_dataflow is not None: + self.strict_dataflow = strict_dataflow + if assume_always_shared is not None: + self.assume_always_shared = assume_always_shared + if allow_serial_map_fusion is not None: + self.allow_serial_map_fusion = allow_serial_map_fusion + if allow_parallel_map_fusion is not None: + self.allow_parallel_map_fusion = allow_parallel_map_fusion + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + + # See comment in `is_shared_data()` for more information. + self._single_use_data: Optional[Dict[dace.SDFG, Set[str]]] = None + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The function returns a list of two expressions. + + The first, index `0`, is used by the serial map fusion. It consists of the + exit node of the first map, `first_map_exit`, the intermediate array, `array`, + and the map entry node of the second map, `second_map_entry`. An important note + is, that the transformation operates not just on the matched nodes, but more + or less on anything that has an incoming connection from the first Map or an + outgoing connection to the second Map entry. + + The second expression, index `1`, is used by parallel map fusion. It matches + any two maps entries, `first_parallel_map_entry` and `second_parallel_map_entry + in a state. + """ + map_fusion_serial_match = dace.sdfg.utils.node_path_graph( + cls.first_map_exit, cls.array, cls.second_map_entry + ) + + map_fusion_parallel_match = graph.OrderedMultiDiConnectorGraph() + map_fusion_parallel_match.add_nodes_from( + [cls.first_parallel_map_entry, cls.second_parallel_map_entry] + ) + + return [map_fusion_serial_match, map_fusion_parallel_match] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Checks if the map fusion can be applied. + + Depending on the value of `expr_index` the function will dispatch the call + either to `can_serial_map_fusion_be_applied()` or + `can_parallel_map_fusion_be_applied()`, see there for more information. + """ + # Perform some checks of the deferred configuration data. + if not (self.allow_parallel_map_fusion or self.allow_serial_map_fusion): + raise ValueError("Disabled serial and parallel map fusion.") + assert expr_index == self.expr_index + assert self.expr_index in [0, 1], f"Found invalid 'expr_index' {self.expr_index}" + + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + # Now perform the dispatch. + if self.allow_serial_map_fusion and expr_index == 0: + return self.can_serial_map_fusion_be_applied( + graph=graph, + sdfg=sdfg, + ) + + elif self.allow_parallel_map_fusion and expr_index == 1: + return self.can_parallel_map_fusion_be_applied( + graph=graph, + sdfg=sdfg, + ) + + # Non of the cases applied + return False + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + """Apply the map fusion. + + Depending on the settings the function will either dispatch to + `apply_serial_map_fusion()` or to `apply_parallel_map_fusion()`. + """ + # Perform some checks of the deferred configuration data. + if not (self.allow_parallel_map_fusion or self.allow_serial_map_fusion): + raise ValueError("Disabled serial and parallel map fusion.") + assert self.expr_index in [0, 1] + + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1708](https://github.com/spcl/dace/issues/1703) + for edge in graph.edges(): + edge.data.try_initialize(sdfg, graph, edge) + + # Now perform the dispatch. + if self.expr_index == 0: + assert self.allow_serial_map_fusion + return self.apply_serial_map_fusion( + graph=graph, + sdfg=sdfg, + ) + + elif self.expr_index == 1: + assert self.allow_parallel_map_fusion + return self.apply_parallel_map_fusion( + graph=graph, + sdfg=sdfg, + ) + + else: + raise NotImplementedError(f"Encountered unknown expression index {self.expr_index}") + + def can_parallel_map_fusion_be_applied( + self, + graph: Union[SDFGState, SDFG], + sdfg: dace.SDFG, + ) -> bool: + """Check if the matched Maps can be fused in parallel.""" + assert self.expr_index == 1 + + # NOTE: The after this point it is not legal to access the matched nodes + first_map_entry: nodes.MapEntry = self.first_parallel_map_entry + second_map_entry: nodes.MapEntry = self.second_parallel_map_entry + + # Check the structural properties of the Maps. The function will return + # the `dict` that describes how the parameters must be renamed (for caching) + # or `None` if the maps can not be structurally fused. + param_repl = self.can_topologically_be_fused( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + graph=graph, + sdfg=sdfg, + ) + if param_repl is None: + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # TODO(phimuell): Improve this such that different AccessNode that refer + # to the same data are also considered the same; Probably an overkill. + first_ancestors: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(first_map_entry)} + if not any(e2.src in first_ancestors for e2 in graph.in_edges(second_map_entry)): + return False + + return True + + def can_serial_map_fusion_be_applied( + self, + graph: Union[SDFGState, SDFG], + sdfg: dace.SDFG, + ) -> bool: + """Tests if the matched Maps can be merged serially. + + The two Maps are mergeable iff: + * Checks general requirements, see `can_topologically_be_fused()`. + * Tests if there are read write dependencies. + * Tests if the decomposition exists. + """ + assert self.expr_index == 0 + + # NOTE: The after this point it is not legal to access the matched nodes + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + + # Check the structural properties of the Maps. The function will return + # the `dict` that describes how the parameters must be renamed (for caching) + # or `None` if the maps can not be structurally fused. + param_repl = self.can_topologically_be_fused( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + graph=graph, + sdfg=sdfg, + ) + if param_repl is None: + return False + + # Tests if there are read write dependencies that are caused by the bodies + # of the Maps, such as referring to the same data. Note that this tests are + # different from the ones performed by `has_read_write_dependency()`, which + # only checks the data dependencies that go through the scope nodes. + if self.has_inner_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + state=graph, + sdfg=sdfg, + ): + return False + + # Tests for read write conflicts of the two maps, this is only checking + # the data that goes through the scope nodes. `has_inner_read_write_dependency()` + # if used to check if there are internal dependencies. + if self.has_read_write_dependency( + first_map_entry=first_map_entry, + second_map_entry=second_map_entry, + param_repl=param_repl, + state=graph, + sdfg=sdfg, + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=param_repl, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + + return True + + def apply_parallel_map_fusion( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + """Performs parallel map fusion. + + Essentially this function will move all input connectors from one map, + i.e. its MapEntry and MapExit nodes, to the other map. + """ + + # NOTE: The after this point it is not legal to access the matched nodes + first_map_entry: nodes.MapEntry = self.first_parallel_map_entry + first_map_exit: nodes.MapExit = graph.exit_node(first_map_entry) + second_map_entry: nodes.MapEntry = self.second_parallel_map_entry + second_map_exit: nodes.MapExit = graph.exit_node(second_map_entry) + + # Before we do anything we perform the renaming, i.e. we will rename the + # parameters of the second map such that they match the one of the first map. + self.rename_map_parameters( + first_map=first_map_entry.map, + second_map=second_map_entry.map, + second_map_entry=second_map_entry, + state=graph, + ) + + # Now we relocate all connectors from the second to the first map and remove + # the respective node of the second map. + for to_node, from_node in [ + (first_map_entry, second_map_entry), + (first_map_exit, second_map_exit), + ]: + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) + + def apply_serial_map_fusion( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + :param graph: The SDFG state we are operating on. + :param sdfg: The SDFG we are operating on. + """ + assert self.expr_index == 0 + + # NOTE: The after this point it is not legal to access the matched nodes + first_map_exit: nodes.MapExit = self.first_map_exit + second_map_entry: nodes.MapEntry = self.second_map_entry + second_map_exit: nodes.MapExit = graph.exit_node(self.second_map_entry) + first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=first_map_exit.map, + second_map=second_map_entry.map, + second_map_entry=second_map_entry, + state=graph, + ) + + # Now compute the partition. Because we have already renamed the parameters + # of the second Map, there is no need to perform any renaming, thus we can + # pass an empty `dict`. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + param_repl=dict(), + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + # Now perform the actual rewiring, we handle each partition separately. + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + first_map_exit=first_map_exit, + second_map_entry=second_map_entry, + second_map_exit=second_map_exit, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(first_map_exit)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=first_map_exit, + to_node=second_map_exit, + state=graph, + sdfg=sdfg, + ) + + # Now move the input of the second map, that has no connection to the first + # map, to the first map. This is needed because we will later delete the + # exit of the first map (which we have essentially handled above). Now + # we must handle the input of the second map (that has no connection to the + # first map) to the input of the first map. + self.relocate_nodes( + from_node=second_map_entry, + to_node=first_map_entry, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [first_map_exit, second_map_entry]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + second_map_exit.map = first_map_entry.map + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `first_map_exit` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + * Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + * Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + * Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + :return: If such a decomposition exists the function will return the three sets + mentioned above in the same order. In case the decomposition does not exist, + i.e. the maps can not be fused the function returns `None`. + + :param state: The in which the two maps are located. + :param sdfg: The full SDFG in whcih we operate. + :param first_map_exit: The exit node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Use this map to rename the parameter of the second Map, such + that they match the one of the first map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(first_map_exit): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=second_map_entry, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + intermediate_desc: dace.data.Data = intermediate_node.desc(sdfg) + if self.is_view(intermediate_desc, sdfg): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) + ) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Melets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view( + producer_edge.src, sdfg + ): + return None + if producer_edge.data.dynamic: + # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely. + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not second_map_entry: + if self.is_node_reachable_from( + graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry + ): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE1: The subset still uses the old iteration variables. + # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets. + # This is different compared to the producer Memlet. The reason is + # because in a consumer the data is conditionally read, so the data + # has to exists anyway. + for inner_consumer_edge in state.out_edges_by_connector( + second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:] + ): + if inner_consumer_edge.data.src_subset is None: + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert ( + found_second_map + ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if param_repl: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace( + mapping=param_repl, replace_callback=consumer_subset.replace + ) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum( + producer_subset.covers(consumer_subset) for producer_subset in producer_subsets + ) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): + # The intermediate data is used somewhere else, either in this or another state. + # NOTE: If the intermediate is shared, then we will turn it into a + # sink node attached to the combined map exit. Technically this + # should be enough, even if the same data appears again in the + # dataflow down streams. However, some DaCe transformations, + # I am looking at you `auto_optimizer()` do not like that. Thus + # if the intermediate is used further down in the same datadflow, + # then we consider that the maps can not be fused. But we only + # do this in the strict data flow mode. + if self.strict_dataflow: + if self._is_data_accessed_downstream( + data=intermediate_node.data, + graph=state, + begin=intermediate_node, # is ignored itself. + ): + return None + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) + + def relocate_nodes( + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. + The function assumes that the parameter renaming was already done. + + :param from_node: Node from which the edges should be removed. + :param to_node: Node to which the edges should reconnect. + :param state: The state in which the operation happens. + :param sdfg: The SDFG that is modified. + """ + + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError( + f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented." + ) + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." + ) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) + + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + first_map_exit: nodes.MapExit, + second_map_entry: nodes.MapEntry, + second_map_exit: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + :param intermediate_outputs: The set of outputs, that should be processed. + :param state: The state in which the map is processed. + :param sdfg: The SDFG that should be optimized. + :param first_map_exit: The exit of the first/top map. + :param second_map_entry: The entry of the second map. + :param second_map_exit: The exit of the second map. + :param is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + :note: Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = first_map_exit.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( + self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + ) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=None, + ) + + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( + include_self=False + ): + producer_edge = producer_tree.edge + + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == second_map_entry: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. + new_inner_memlet = copy.deepcopy(inner_edge.data) + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( + include_self=False + ): + consumer_edge = consumer_tree.edge + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = second_map_exit.next_connector() + state.add_edge( + new_inter_node, + None, + second_map_exit, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + second_map_exit, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) + + first_map_exit.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Union[subsets.Range, None], + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + This is the value that must be substracted from the memlets to adjust, i.e + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. the + memlet tree that originates at `intermediate_desc`. + + :param original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + :param intermediate_desc: The original intermediate data descriptor. + :param map_params: The parameter of the final map. + :param producer_offset: The correction that was applied to the producer side. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + # If the intermediate was a scalar, then it will remain a scalar. + # Thus there is no correction that we must apply. + return subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError( + f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." + ) + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + def can_topologically_be_fused( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> Optional[Dict[str, str]]: + """Performs basic checks if the maps can be fused. + + This function only checks constrains that are common between serial and + parallel map fusion process, which includes: + * The scope of the maps. + * The scheduling of the maps. + * The map parameters. + + :return: If the maps can not be topologically fused the function returns `None`. + If they can be fused the function returns `dict` that describes parameter + replacement, see `find_parameter_remapping()` for more. + + :param first_map_entry: The entry of the first (in serial case the top) map. + :param second_map_exit: The entry of the second (in serial case the bottom) map. + :param graph: The SDFGState in which the maps are located. + :param sdfg: The SDFG itself. + :param permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError( + "Only one of `only_inner_maps` and `only_toplevel_maps` is allowed per MapFusion instance." + ) + + # Ensure that both have the same schedule + if first_map_entry.map.schedule != second_map_entry.map.schedule: + return None + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[first_map_entry] != scope[second_map_entry]: + return None + elif self.only_inner_maps: + if scope[first_map_entry] is None: + return None + elif self.only_toplevel_maps: + if scope[first_map_entry] is not None: + return None + + # We will now check if we can rename the Map parameter of the second Map such that they + # match the one of the first Map. + param_repl = self.find_parameter_remapping( + first_map=first_map_entry.map, second_map=second_map_entry.map + ) + return param_repl + + def is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel in the data flow graph. + + The function considers two nodes parallel in the data flow graph, if `node2` + can not be reached from `node1` and vice versa. + + :param graph: The state on which we operate. + :param node1: The first node to check. + :param node2: The second node to check. + """ + # In order to be parallel they must be in the same scope. + scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): + return False + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): + return False + return True + + def has_inner_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """This function tests if there are dependency inside the Maps. + + The function will scan and anaysize the body of the two Maps and look for + inconsistencies. To detect them the function will scan the body of the maps + and examine the all AccessNodes and apply the following rules: + * If an AccessNode refers to a View, it is ignored. Because the source is + either on the outside, in which case `has_read_write_dependency()` + takes care of it, or the data source is inside the Map body itself. + * An inconsistency is detected, if in each bodies there exists an AccessNode + that refer to the same data. + * An inconsistency is detected, if there exists an AccessNode that refers + to non transient data. This is an implementation detail and could be + lifted. + + Note that some of the restrictions of this function could be relaxed by + performing more analysis. + + :return: The function returns `True` if an inconsistency has been found. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. + """ + first_map_body = state.scope_subgraph(first_map_entry, False, False) + second_map_body = state.scope_subgraph(second_map_entry, False, False) + + # Find the data that is internally referenced. Because of the first rule above, + # we filter all views above. + first_map_body_data, second_map_body_data = [ + { + dnode.data + for dnode in map_body.nodes() + if isinstance(dnode, nodes.AccessNode) and not self.is_view(dnode, sdfg) + } + for map_body in [first_map_body, second_map_body] + ] + + # If there is data that is referenced in both, then we consider this as an error + # this is the second rule above. + if not first_map_body_data.isdisjoint(second_map_body_data): + return True + + # We consider it as a problem if any map refers to non-transient data. + # This is an implementation detail and could be dropped if we do further + # analysis. + if any( + not sdfg.arrays[data].transient + for data in first_map_body_data.union(second_map_body_data) + ): + return True + + return False + + def has_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks three different things. + * The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets of the inputs of the MapEntry of the first and the + outputs of the MapExit node of the second map. + * The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as output of the + combined map. Note that it is allowed that an intermediate node is also + an input to the first map. + * In case an intermediate node, is also used as input node of the first map, + it is forbidden that the data is used as output of the second map, the + function will do additional checks. This is needed as the partition function + only checks the data consumption of the second map can be satisfied by the + data production of the first map, it ignores any potential reads made by + the first map's MapEntry. + + :return: `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` is returned. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Dict that describes how to rename the parameters of the second Map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. + """ + first_map_exit: nodes.MapExit = state.exit_node(first_map_entry) + second_map_exit: nodes.MapExit = state.exit_node(second_map_entry) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [first_map_entry, first_map_exit, second_map_entry, second_map_exit]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append( + { + self.track_view(node, state, sdfg).data + if self.is_view(node, sdfg) + else node.data + for node in unresolved_set.values() + } + ) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + # This essentially ensures that an intermediate can not be used as output of + # the second map at the same time. It is actually stronger as it does not + # take their role into account. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( + read_map_2.values() + ) + + # If the number are different then a data is accessed through different + # AccessNodes. We could analyse this, but we will consider this as a data race. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + # NOTE: This set is not required to be empty. It might look as this would + # create a data race, but it is save. The reason is because all data has + # to pass through the intermediate we create, this will separate the reads + # from the writes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can not be used as output (of the second as well as the + # combined map) and as intermediate. If we would allow that the map would + # have two output nodes one the original one and the second is the created + # node that is created because the intermediate is shared. + # TODO(phimuell): Handle this case. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # While it is forbidden that a data container, used as intermediate, is also + # used as output of the second map. It is allowed that the data container + # is used as intermediate and as input of the first map. The partition only + # checks that the data dependencies are mean, i.e. what is read by the second + # map is also computed (written to the intermediate) it does not take into + # account the first map's read to the data container. + # To make an example: The partition function will make sure that if the + # second map reads index `i` from the intermediate that the first map writes + # to that index. But it will not care if the first map reads (through its + # MapEntry) index `i + 1`. In order to be valid me must ensure that the first + # map's reads and writes to the intermediate are pointwise. + # Note that we only have to make this check if it is also an intermediate node. + # Because if it is not read by the second map it is not a problem as the node + # will end up as an pure output node anyway. + read_write_map_1 = set(read_map_1.keys()).intersection(write_map_1.keys()) + datas_to_inspect = read_write_map_1.intersection(exchange_names) + for data_to_inspect in datas_to_inspect: + # Now get all subsets of the data container that the first map reads + # from or writes to and check if they are pointwise. + all_subsets: List[subsets.Subset] = [] + all_subsets.extend( + self.find_subsets( + node=read_map_1[data_to_inspect], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + all_subsets.extend( + self.find_subsets( + node=write_map_1[data_to_inspect], + scope_node=first_map_exit, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + del all_subsets + + # If there is no intersection between the input and output data, then we can + # we have nothing to check. + if len(fused_inout_data_names) == 0: + return False + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=second_map_exit, + state=state, + sdfg=sdfg, + param_repl=param_repl, + ) + ) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + del all_subsets + + # No read write dependency was found. + return False + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + :param subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + def is_shared_data( + self, + data: nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. + + Depending on the situation, the function will not perform a scan of the whole SDFG: + 1) If `assume_always_shared` was set to `True`, the function will return `True` unconditionally. + 2) If `data` is non transient then the function will return `True`, as non transient data + must be reconstructed always. + 3) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge + it is classified as shared. + 2) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. + 3) The function will perform a scan. + + :param data: The transient that should be checked. + :param state: The state in which the fusion is performed. + :param sdfg: The SDFG in which we want to perform the fusing. + + """ + # `assume_always_shared` takes precedence. + if self.assume_always_shared: + return True + + # If `data` is non transient then return `True` as the intermediate can not be removed. + if not data.desc(sdfg).transient: + return True + + # This means the data is consumed by multiple Maps, through the same AccessNode, in this state + # Note currently multiple incoming edges are not handled, but in the spirit of this function + # we consider such AccessNodes as shared, because we can not remove the intermediate. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + # NOTE: Actually, if this transformation is run through the `FullMapFusion` pass, it should + # read the results from `FindSingelUseData`, that was computed because it is a dependent + # pass through the `self._pipeline_results` which is set by the `SingleStateTransformation`. + # However, this member is only set during when `apply()` is called, but not during + # `can_be_applied()`, see [issue#1911](https://github.com/spcl/dace/issues/1911). + # Because, the whole goal of this separation of scanning and fusion was to make the + # transformation stateless, the member `_single_use_data` was introduced. If it is set + # then we use it otherwise we use the scanner. + # This value is set for example by the `FullMapFusion` pass. + # TODO(phimuell): Change this once the issue is resolved. + if self._single_use_data is not None: + assert ( + sdfg in self._single_use_data + ), f"`_single_use_data` was set, but does not contain information about the SDFG '{sdfg.name}'." + single_use_data: Set[str] = self._single_use_data[sdfg] + return data.data not in single_use_data + + # We have to perform the full scan of the SDFG. + return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg) + + def _scan_sdfg_if_data_is_shared( + self, + data: nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Scans `sdfg` to determine if `data` is shared. + + Essentially, this function determine, if the intermediate AccessNode `data` is + can be removed or if it has to be restored as output of the Map. + A data descriptor is classified as shared if any of the following is true: + - `data` is non transient data. + - `data` has at most one incoming and/or outgoing edge. + - There are other AccessNodes beside `data` that refer to the same data. + - The data is accessed on an interstate edge. + + This function should not be called directly. Instead it is called indirectly + by `is_shared_data()` if there is no short cut. + + :param data: The AccessNode that should checked if it is shared. + :param sdfg: The SDFG for which the set of shared data should be computed. + """ + if not data.desc(sdfg).transient: + return True + + # See description in `is_shared_data()` for more. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + data_name: str = data.data + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode is data: + # We have found the `data` AccessNode, which we must ignore. + continue + if dnode.data == data_name: + # We found a different AccessNode that refers to the same data + # as `data`. Thus `data` is shared. + return True + + # Test if the data is referenced in the interstate edges. + for edge in sdfg.edges(): + if data_name in edge.data.free_symbols: + # The data is used in the inter state edges. So it is shared. + return True + + # Test if the data is referenced inside a control flow, such as a conditional + # block or loop condition. + for cfr in sdfg.all_control_flow_regions(): + if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): + return True + + # The `data` is not used anywhere else, thus `data` is not shared. + return False + + def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map + ) -> Optional[Dict[str, str]]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at is _needed_, i.e. all parameter have the same name and range, + then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. + + :param first_map: The first map (these parameters will be replaced). + :param second_map: The second map, these parameters acts as source. + + :note: This function currently fails if the renaming is not unique. Consider the + case were the first map has the structure `for i, j in map[0:20, 0:20]` and it + writes `T[i, j]`, while the second map is equivalent to + `for l, k in map[0:20, 0:20]` which reads `T[l, k]`. For this case we have + the following valid remappings `{l: i, k: j}` and `{l: j, k: i}` but + only the first one allows to fuse the map. This is because if the second + one is used the second map will read `T[j, i]` which leads to a data + dependency that can not be satisfied. + To avoid this issue the renaming algorithm will process them in order, i.e. + assuming that the order of the parameters in the map matches. But this is + not perfect, the only way to really solve this is by trying possible + remappings. At least the algorithm used here is deterministic. + """ + + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 [lambda-assignment] + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and the parameters of the first map that are still free. + # That we use a `list` instead of a `set` is intentional, because it counter + # acts the issue that is described in the doc string. Using a list ensures + # that they indexes are matched in order. This assume that in real world + # code the order of the loop is not arbitrary but kind of matches. + unmapped_second_params: List[str] = list(second_params) + unused_first_params: List[str] = list(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.remove(param) + unused_first_params.remove(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in list(unused_first_params): + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.remove(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None + + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping + + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + :param first_map: The first map (these are the final parameter). + :param second_map: The second map, this map will be replaced. + :param second_map_entry: The entry node of the second map. + :param state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] # Guaranteed to be not `None`. + first_map=first_map, second_map=second_map + ) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) + + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + def is_node_reachable_from( + self, + graph: dace.SDFGState, + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + :param graph: The graph to operate on. + :param begin: The start of the DFS. + :param end: The node that should be located. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() + + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + def _is_data_accessed_downstream( + self, + data: str, + graph: dace.SDFGState, + begin: nodes.Node, + ) -> bool: + """Tests if there is an AccessNode for `data` downstream of `begin`. + + Essentially, this function starts a DFS at `begin` and checks every + AccessNode that is reachable from it. If it finds such a node it will + check if it refers to `data` and if so, it will return `True`. + If no such node is found it will return `False`. + Note that the node `begin` will be ignored. + + :param data: The name of the data to look for. + :param graph: The graph to explore. + :param begin: The node to start exploration; The node itself is ignored. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + # Dataflow graph is acyclic, so we do not need to keep a list of + # what we have visited. + to_visit: List[nodes.Node] = list(next_nodes(begin)) + while len(to_visit) > 0: + node = to_visit.pop() + if isinstance(node, nodes.AccessNode) and node.data == data: + return True + to_visit.extend(next_nodes(node)) + + return False + + def get_access_set( + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. + + :param scope_node: The scope node that should be evaluated. + :param state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) # noqa: E731 [lambda-assignment] + other_node = lambda e: e.src # noqa: E731 [lambda-assignment] + else: + get_edges = lambda node: state.out_edges(node) # noqa: E731 [lambda-assignment] + other_node = lambda e: e.dst # noqa: E731 [lambda-assignment] + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } + + return access_set + + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + param_repl: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. + + :param node: The access node that should be examined. + :param scope_node: We are only interested in data that flows through this node. + :param state: The state in which we operate. + :param sdfg: The SDFG object. + :param param_repl: `dict` that describes the parameter renaming that should be + performed. Can be `None` to skip the processing. + """ + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset # noqa: E731 [lambda-assignment] + get_inner_edges = ( # noqa: E731 [lambda-assignment] + lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) + ) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset # noqa: E731 [lambda-assignment] + get_inner_edges = ( # noqa: E731 [lambda-assignment] + lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) + ) + + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + found_subsets = copy.deepcopy(found_subsets) + if param_repl: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(param_repl, subset.replace) + + return found_subsets + + def is_view( + self, + node: Union[nodes.AccessNode, data.Data], + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node if isinstance(node, data.Data) else node.desc(sdfg) + return isinstance(node_desc, data.View) + + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + :param view: The view that should be traced. + :param state: The state in which we operate. + :param sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view + + # This is the node that defines the view. + defining_node = dace.sdfg.utils.get_last_view_node(state, view) + assert isinstance(defining_node, nodes.AccessNode) + assert not self.is_view(defining_node, sdfg) + return defining_node diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py deleted file mode 100644 index 03e5973c3c..0000000000 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py +++ /dev/null @@ -1,676 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements Helper functionaliyies for map fusion - -THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN -DACE IS MERGED AND THE VERSION WAS UPGRADED. -""" - - -# ruff: noqa - -import copy -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, Callable, TypeAlias - -import dace -from dace import data, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, nodes, validation -from dace.transformation import helpers - -FusionCallback: TypeAlias = Callable[ - ["MapFusionHelper", nodes.MapEntry, nodes.MapEntry, dace.SDFGState, dace.SDFG, bool], bool -] -"""Callback for the map fusion transformation to check if a fusion should be performed. -""" - - -@properties.make_properties -class MapFusionHelper(transformation.SingleStateTransformation): - """Common parts of the parallel and serial map fusion transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the transformation ensures a more - stricter version of the data flow. - apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, - to check if a fusion should be performed. - - Note: - If `strict_dataflow` mode is enabled then the transformation will not remove - _direct_ data flow dependency from the graph. Furthermore, the transformation - will not remove size 1 dimensions of intermediate it creates. - This is a compatibility mode, that will limit the applicability of the - transformation, but might help transformations that do not fully analyse - the graph. - """ - - only_toplevel_maps = properties.Property( - dtype=bool, - default=False, - desc="Only perform fusing if the Maps are in the top level.", - ) - only_inner_maps = properties.Property( - dtype=bool, - default=False, - desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", - ) - strict_dataflow = properties.Property( - dtype=bool, - default=False, - desc="If `True` then the transformation will ensure a more stricter data flow.", - ) - - # Callable that can be specified by the user, if it is specified, it should be - # a callable with the same signature as `can_be_fused()`. If the function returns - # `False` then the fusion will be rejected. - _apply_fusion_callback: Optional[FusionCallback] - - def __init__( - self, - only_inner_maps: Optional[bool] = None, - only_toplevel_maps: Optional[bool] = None, - strict_dataflow: Optional[bool] = None, - apply_fusion_callback: Optional[FusionCallback] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self._shared_data = {} # type: ignore[var-annotated] - self._apply_fusion_callback = None - if only_toplevel_maps is not None: - self.only_toplevel_maps = bool(only_toplevel_maps) - if only_inner_maps is not None: - self.only_inner_maps = bool(only_inner_maps) - if strict_dataflow is not None: - self.strict_dataflow = bool(strict_dataflow) - if apply_fusion_callback is not None: - self._apply_fusion_callback = apply_fusion_callback - - @classmethod - def expressions(cls) -> bool: - raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") - - def can_be_fused( - self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Performs basic checks if the maps can be fused. - - This function only checks constrains that are common between serial and - parallel map fusion process, which includes: - - The registered callback, if specified. - - The scope of the maps. - - The scheduling of the maps. - - The map parameters. - - Args: - map_entry_1: The entry of the first (in serial case the top) map. - map_exit_2: The entry of the second (in serial case the bottom) map. - graph: The SDFGState in which the maps are located. - sdfg: The SDFG itself. - permissive: Currently unused. - """ - # Consult the callback if defined. - if self._apply_fusion_callback is not None: - if not self._apply_fusion_callback( - self, map_entry_1, map_entry_2, graph, sdfg, permissive - ): - return False - - if self.only_inner_maps and self.only_toplevel_maps: - raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") - - # Ensure that both have the same schedule - if map_entry_1.map.schedule != map_entry_2.map.schedule: - return False - - # Fusing is only possible if the two entries are in the same scope. - scope = graph.scope_dict() - if scope[map_entry_1] != scope[map_entry_2]: - return False - elif self.only_inner_maps: - if scope[map_entry_1] is None: - return False - elif self.only_toplevel_maps: - if scope[map_entry_1] is not None: - return False - - # We will now check if there exists a remapping of the map parameter - if ( - self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) - is None - ): - return False - - return True - - def relocate_nodes( - self, - from_node: Union[nodes.MapExit, nodes.MapEntry], - to_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - ) -> None: - """Move the connectors and edges from `from_node` to `to_nodes` node. - - This function will only rewire the edges, it does not remove the nodes - themselves. Furthermore, this function should be called twice per Map, - once for the entry and then for the exit. - While it does not remove the node themselves if guarantees that the - `from_node` has degree zero. - The function assumes that the parameter renaming was already done. - - Args: - from_node: Node from which the edges should be removed. - to_node: Node to which the edges should reconnect. - state: The state in which the operation happens. - sdfg: The SDFG that is modified. - """ - - # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): - helpers.redirect_edge(state, empty_edge, new_dst=to_node) - - # We now ensure that there is only one empty Memlet from the `to_node` to any other node. - # Although it is allowed, we try to prevent it. - empty_targets: Set[nodes.Node] = set() - for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): - if empty_edge.dst in empty_targets: - state.remove_edge(empty_edge) - empty_targets.add(empty_edge.dst) - - # We now determine which edges we have to migrate, for this we are looking at - # the incoming edges, because this allows us also to detect dynamic map ranges. - # TODO(phimuell): If there is already a connection to the node, reuse this. - for edge_to_move in list(state.in_edges(from_node)): - assert isinstance(edge_to_move.dst_conn, str) - - if not edge_to_move.dst_conn.startswith("IN_"): - # Dynamic Map Range - # The connector name simply defines a variable name that is used, - # inside the Map scope to define a variable. We handle it directly. - dmr_symbol = edge_to_move.dst_conn - - # TODO(phimuell): Check if the symbol is really unused in the target scope. - if dmr_symbol in to_node.in_connectors: - raise NotImplementedError( - f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" - f" to '{to_node}', but the symbol is already known there, but the" - " renaming is not implemented." - ) - if not to_node.add_in_connector(dmr_symbol, force=False): - raise RuntimeError( # Might fail because of out connectors. - f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." - ) - helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) - from_node.remove_in_connector(dmr_symbol) - - else: - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): - helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): - helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) - - # Check if we succeeded. - if state.out_degree(from_node) != 0: - raise validation.InvalidSDFGError( - f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - if state.in_degree(from_node) != 0: - raise validation.InvalidSDFGError( - f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - assert len(from_node.in_connectors) == 0 - assert len(from_node.out_connectors) == 0 - - def find_parameter_remapping( - self, first_map: nodes.Map, second_map: nodes.Map - ) -> Union[Dict[str, str], None]: - """Computes the parameter remapping for the parameters of the _second_ map. - - The returned `dict` maps the parameters of the second map (keys) to parameter - names of the first map (values). Because of how the replace function works - the `dict` describes how to replace the parameters of the second map - with parameters of the first map. - Parameters that already have the correct name and compatible range, are not - included in the return value, thus the keys and values are always different. - If no renaming at all is _needed_, i.e. all parameter have the same name and - range, then the function returns an empty `dict`. - If no remapping exists, then the function will return `None`. - - Args: - first_map: The first map (these parameters will be replaced). - second_map: The second map, these parameters acts as source. - """ - - # The parameter names - first_params: List[str] = first_map.params - second_params: List[str] = second_map.params - - if len(first_params) != len(second_params): - return None - - # The ranges, however, we apply some post processing to them. - simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 - first_rngs: Dict[str, Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) - } - second_rngs: Dict[str, Tuple[Any, Any, Any]] = { - param: tuple(simp(r) for r in rng) - for param, rng in zip(second_params, second_map.range) - } - - # Parameters of the second map that have not yet been matched to a parameter - # of the first map and vice versa. - unmapped_second_params: Set[str] = set(second_params) - unused_first_params: Set[str] = set(first_params) - - # This is the result (`second_param -> first_param`), note that if no renaming - # is needed then the parameter is not present in the mapping. - final_mapping: Dict[str, str] = {} - - # First we identify the parameters that already have the correct name. - for param in set(first_params).intersection(second_params): - first_rng = first_rngs[param] - second_rng = second_rngs[param] - - if first_rng == second_rng: - # They have the same name and the same range, this is already a match. - # Because the names are already the same, we do not have to enter them - # in the `final_mapping` - unmapped_second_params.discard(param) - unused_first_params.discard(param) - - # Check if no remapping is needed. - if len(unmapped_second_params) == 0: - return {} - - # Now we go through all the parameters that we have not mapped yet. - # All of them will result in a remapping. - for unmapped_second_param in unmapped_second_params: - second_rng = second_rngs[unmapped_second_param] - assert unmapped_second_param not in final_mapping - - # Now look in all not yet used parameters of the first map which to use. - for candidate_param in unused_first_params: - candidate_rng = first_rngs[candidate_param] - if candidate_rng == second_rng: - final_mapping[unmapped_second_param] = candidate_param - unused_first_params.discard(candidate_param) - break - else: - # We did not find a candidate, so the remapping does not exist - return None - - assert len(unused_first_params) == 0 - assert len(final_mapping) == len(unmapped_second_params) - return final_mapping - - def rename_map_parameters( - self, - first_map: nodes.Map, - second_map: nodes.Map, - second_map_entry: nodes.MapEntry, - state: SDFGState, - ) -> None: - """Replaces the map parameters of the second map with names from the first. - - The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is - handled correct. The function assumes that a proper replacement exists. - The replacement is computed by calling `self.find_parameter_remapping()`. - - Args: - first_map: The first map (these are the final parameter). - second_map: The second map, this map will be replaced. - second_map_entry: The entry node of the second map. - state: The SDFGState on which we operate. - """ - # Compute the replacement dict. - repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] - first_map=first_map, second_map=second_map - ) - - if repl_dict is None: - raise RuntimeError("The replacement does not exist") - if len(repl_dict) == 0: - return - - second_map_scope = state.scope_subgraph(entry_node=second_map_entry) - # Why is this thing is symbolic and not in replace? - symbolic.safe_replace( - mapping=repl_dict, - replace_callback=second_map_scope.replace_dict, - ) - - # For some odd reason the replace function does not modify the range and - # parameter of the map, so we will do it the hard way. - second_map.params = copy.deepcopy(first_map.params) - second_map.range = copy.deepcopy(first_map.range) - - def is_shared_data( - self, - data: nodes.AccessNode, - state: dace.SDFGState, - sdfg: dace.SDFG, - ) -> bool: - """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. - - Depending on the situation, the function will not perform a scan of the whole SDFG: - 1) If `data` is non transient then the function will return `True`, as non transient data - must be reconstructed always. - 2) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge - it is classified as shared. - 3) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. - 4) The function will perform a scan. - - :param data: The transient that should be checked. - :param state: The state in which the fusion is performed. - :param sdfg: The SDFG in which we want to perform the fusing. - - """ - # If `data` is non transient then return `True` as the intermediate can not be removed. - if not data.desc(sdfg).transient: - return True - - # This means the data is consumed by multiple Maps, through the same AccessNode, in this state - # Note currently multiple incoming edges are not handled, but in the spirit of this function - # we consider such AccessNodes as shared, because we can not remove the intermediate. - if state.out_degree(data) > 1: - return True - if state.in_degree(data) > 1: - return True - - # We have to perform the full scan of the SDFG. - return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg) - - def _scan_sdfg_if_data_is_shared( - self, - data: nodes.AccessNode, - state: dace.SDFGState, - sdfg: dace.SDFG, - ) -> bool: - """Scans `sdfg` to determine if `data` is shared. - - Essentially, this function determines if the intermediate AccessNode `data` - can be removed or if it has to be restored as output of the Map. - A data descriptor is classified as shared if any of the following is true: - - `data` is non transient data. - - `data` has at most one incoming and/or outgoing edge. - - There are other AccessNodes beside `data` that refer to the same data. - - The data is accessed on an interstate edge. - - This function should not be called directly. Instead it is called indirectly - by `is_shared_data()` if there is no short cut. - - :param data: The AccessNode that should checked if it is shared. - :param sdfg: The SDFG for which the set of shared data should be computed. - """ - if not data.desc(sdfg).transient: - return True - - # See description in `is_shared_data()` for more. - if state.out_degree(data) > 1: - return True - if state.in_degree(data) > 1: - return True - - data_name: str = data.data - for state in sdfg.states(): - for dnode in state.data_nodes(): - if dnode is data: - # We have found the `data` AccessNode, which we must ignore. - continue - if dnode.data == data_name: - # We found a different AccessNode that refers to the same data - # as `data`. Thus `data` is shared. - return True - - # Test if the data is referenced in the interstate edges. - for edge in sdfg.edges(): - if data_name in edge.data.free_symbols: - # The data is used in the inter state edges. So it is shared. - return True - - # Test if they are accessed in a condition of a loop or conditional block. - for cfr in sdfg.all_control_flow_regions(): - if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): - return True - - # The `data` is not used anywhere else, thus `data` is not shared. - return False - - def _compute_multi_write_data( - self, - state: SDFGState, - sdfg: SDFG, - ) -> Set[str]: - """Computes data inside a _single_ state, that is written multiple times. - - Essentially this function computes the set of data that does not follow - the single static assignment idiom. The function also resolves views. - If an access node, refers to a view, not only the view itself, but also - the data it refers to is added to the set. - - Args: - state: The state that should be examined. - sdfg: The SDFG object. - - Note: - This information is used by the partition function (in case strict data - flow mode is enabled), in strict data flow mode only. The current - implementation is rather simple as it only checks if a data is written - to multiple times in the same state. - """ - data_written_to: Set[str] = set() - multi_write_data: Set[str] = set() - - for access_node in state.data_nodes(): - if state.in_degree(access_node) == 0: - continue - if access_node.data in data_written_to: - multi_write_data.add(access_node.data) - elif self.is_view(access_node, sdfg): - # This is an over approximation. - multi_write_data.update( - [access_node.data, self.track_view(access_node, state, sdfg).data] - ) - data_written_to.add(access_node.data) - return multi_write_data - - def is_node_reachable_from( - self, - graph: dace.SDFGState, - begin: nodes.Node, - end: nodes.Node, - ) -> bool: - """Test if the node `end` can be reached from `begin`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end` the function returns `True`. If the node is never found `False` is - returned. - - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The node that should be located. - """ - - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: - return (edge.dst for edge in graph.out_edges(node)) - - to_visit: List[nodes.Node] = [begin] - seen: Set[nodes.Node] = set() - - while len(to_visit) > 0: - node: nodes.Node = to_visit.pop() - if node == end: - return True - elif node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) - - # We never found `end` - return False - - def get_access_set( - self, - scope_node: Union[nodes.MapEntry, nodes.MapExit], - state: SDFGState, - ) -> Set[nodes.AccessNode]: - """Computes the access set of a "scope node". - - If `scope_node` is a `MapEntry` it will operate on the set of incoming edges - and if it is an `MapExit` on the set of outgoing edges. The function will - then determine all access nodes that have a connection through these edges - to the scope nodes (edges that does not lead to access nodes are ignored). - The function returns a set that contains all access nodes that were found. - It is important that this set will also contain views. - - Args: - scope_node: The scope node that should be evaluated. - state: The state in which we operate. - """ - if isinstance(scope_node, nodes.MapEntry): - get_edges = lambda node: state.in_edges(node) # noqa: E731 - other_node = lambda e: e.src # noqa: E731 - else: - get_edges = lambda node: state.out_edges(node) # noqa: E731 - other_node = lambda e: e.dst # noqa: E731 - access_set: Set[nodes.AccessNode] = { - node - for node in map(other_node, get_edges(scope_node)) - if isinstance(node, nodes.AccessNode) - } - - return access_set - - def find_subsets( - self, - node: nodes.AccessNode, - scope_node: Union[nodes.MapExit, nodes.MapEntry], - state: SDFGState, - sdfg: SDFG, - repl_dict: Optional[Dict[str, str]], - ) -> List[subsets.Subset]: - """Finds all subsets that access `node` within `scope_node`. - - The function will not start a search for all consumer/producers. - Instead it will locate the edges which is immediately inside the - map scope. - - Args: - node: The access node that should be examined. - scope_node: We are only interested in data that flows through this node. - state: The state in which we operate. - sdfg: The SDFG object. - """ - - # Is the node used for reading or for writing. - # This influences how we have to proceed. - if isinstance(scope_node, nodes.MapEntry): - outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] - get_subset = lambda e: e.data.src_subset # noqa: E731 - get_inner_edges = lambda e: state.out_edges_by_connector( - scope_node, "OUT_" + e.dst_conn[3:] - ) - else: - outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] - get_subset = lambda e: e.data.dst_subset # noqa: E731 - get_inner_edges = lambda e: state.in_edges_by_connector( - scope_node, "IN_" + e.src_conn[4:] - ) - - found_subsets: List[subsets.Subset] = [] - for edge in outer_edges_to_inspect: - found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) - assert len(found_subsets) > 0, "Could not find any subsets." - assert not any(subset is None for subset in found_subsets) - - found_subsets = copy.deepcopy(found_subsets) - if repl_dict: - for subset in found_subsets: - # Replace happens in place - symbolic.safe_replace(repl_dict, subset.replace) - - return found_subsets - - def is_view( - self, - node: nodes.AccessNode, - sdfg: SDFG, - ) -> bool: - """Tests if `node` points to a view or not.""" - node_desc: data.Data = node.desc(sdfg) - return isinstance(node_desc, data.View) - - def track_view( - self, - view: nodes.AccessNode, - state: SDFGState, - sdfg: SDFG, - ) -> nodes.AccessNode: - """Find the original data of a View. - - Given the View `view`, the function will trace the view back to the original - access node. For convenience, if `view` is not a `View` the argument will be - returned. - - Args: - view: The view that should be traced. - state: The state in which we operate. - sdfg: The SDFG on which we operate. - """ - - # Test if it is a view at all, if not return the passed node as source. - if not self.is_view(view, sdfg): - return view - - # First determine if the view is used for reading or writing. - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") - if curr_edge.dst_conn == "views": - # The view is used for reading. - next_node = lambda curr_edge: curr_edge.src # noqa: E731 - elif curr_edge.src_conn == "views": - # The view is used for writing. - next_node = lambda curr_edge: curr_edge.dst # noqa: E731 - else: - raise RuntimeError( - f"Failed to determine the direction of the view '{view}' | {curr_edge}." - ) - - # Now trace the view back. - org_view = view - view = next_node(curr_edge) - while self.is_view(view, sdfg): - curr_edge = dace.sdfg.utils.get_view_edge(state, view) - if curr_edge is None: - raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") - view = next_node(curr_edge) - return view diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py deleted file mode 100644 index 19412b9dfa..0000000000 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_parallel.py +++ /dev/null @@ -1,170 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements the parallel map fusing transformation. - -THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN -DACE IS MERGED AND THE VERSION WAS UPGRADED. -""" - -from typing import Any, Optional, Set, Union - -import dace -from dace import properties, transformation -from dace.sdfg import SDFG, SDFGState, graph, nodes - -from . import map_fusion_helper as mfh - - -@properties.make_properties -class MapFusionParallel(mfh.MapFusionHelper): - """The `MapFusionParallel` transformation allows to merge two parallel maps. - - While the `MapFusionSerial` transformation fuses maps that are sequentially - connected through an intermediate node, this transformation is able to fuse any - two maps that are not sequential and in the same scope. - - Args: - only_if_common_ancestor: Only perform fusion if both Maps share at least one - node as direct ancestor. This will increase the locality of the merge. - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, - to check if a fusion should be performed. - - Note: - This transformation only matches the entry nodes of the Map, but will also - modify the exit nodes of the Maps. - """ - - map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) - map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) - - only_if_common_ancestor = properties.Property( - dtype=bool, - default=False, - allow_none=False, - desc="Only perform fusing if the Maps share a node as parent.", - ) - - def __init__( - self, - only_if_common_ancestor: Optional[bool] = None, - **kwargs: Any, - ) -> None: - if only_if_common_ancestor is not None: - self.only_if_common_ancestor = only_if_common_ancestor - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - # This just matches _any_ two Maps inside a state. - state = graph.OrderedMultiDiConnectorGraph() - state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) - return [state] - - def can_be_applied( - self, - graph: Union[SDFGState, SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Checks if the fusion can be done. - - The function checks the general fusing conditions and if the maps are parallel. - """ - map_entry_1: nodes.MapEntry = self.map_entry_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - - # Check the structural properties of the maps, this will also ensure that - # the two maps are in the same scope and the parameters can be renamed - if not self.can_be_fused( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, - graph=graph, - sdfg=sdfg, - permissive=permissive, - ): - return False - - # Since the match expression matches any two Maps, we have to ensure that - # the maps are parallel. The `can_be_fused()` function already verified - # if they are in the same scope. - if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): - return False - - # Test if they have they share a node as direct ancestor. - if self.only_if_common_ancestor: - # This assumes that there is only one access node per data container in the state. - ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} - if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): - return False - - return True - - def is_parallel( - self, - graph: SDFGState, - node1: nodes.Node, - node2: nodes.Node, - ) -> bool: - """Tests if `node1` and `node2` are parallel. - - The nodes are parallel if `node2` can not be reached from `node1` and vice versa. - - Args: - graph: The graph to traverse. - node1: The first node to check. - node2: The second node to check. - """ - - # In order to be parallel they must be in the same scope. - scope = graph.scope_dict() - if scope[node1] != scope[node2]: - return False - - # The `all_nodes_between()` function traverse the graph and returns `None` if - # `end` was not found. We have to call it twice, because we do not know - # which node is upstream if they are not parallel. - if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): - return False - elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): - return False - return True - - def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: - """Performs the Map fusing. - - Essentially, the function relocate all edges from the scope nodes (`MapEntry` - and `MapExit`) of the second map to the scope nodes of the first map. - """ - - map_entry_1: nodes.MapEntry = self.map_entry_1 - map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) - map_entry_2: nodes.MapEntry = self.map_entry_2 - map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) - - # Before we do anything we perform the renaming. - self.rename_map_parameters( - first_map=map_entry_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, - ) - - for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): - self.relocate_nodes( - from_node=from_node, - to_node=to_node, - state=graph, - sdfg=sdfg, - ) - # The relocate function does not remove the node, so we must do it. - graph.remove_node(from_node) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py deleted file mode 100644 index 977f2933b5..0000000000 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ /dev/null @@ -1,1053 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -"""Implements the serial map fusing transformation. - -THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN -DACE IS MERGED AND THE VERSION WAS UPGRADED. -""" - -import copy -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -import dace -from dace import data, properties, subsets, symbolic, transformation -from dace.sdfg import SDFG, SDFGState, graph, nodes - -from . import map_fusion_helper as mfh - - -@properties.make_properties -class MapFusionSerial(mfh.MapFusionHelper): - """Fuse two serial maps together. - - The transformation combines two maps into one that are connected through some - access nodes. Conceptually this transformation removes the exit of the first - or upper map and the entry of the lower or second map and then rewrites the - connections appropriately. Depending on the situation the transformation will - either fully remove or make the intermediate a new output of the second map. - - By default, the transformation does not use the strict data flow mode, see - `MapFusionHelper` for more, however, it might be useful in come cases to enable - it, especially in the context of DaCe's auto optimizer. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - strict_dataflow: If `True`, the transformation ensures a more - stricter version of the data flow. - apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, - to check if a fusion should be performed. - - Notes: - - This transformation modifies more nodes than it matches. - - After the transformation has been applied simplify should be run to remove - some dead data flow, that was introduced to ensure validity. - - A `MapFusionSerial` object can be initialized and be reused. However, - after new access nodes are added to any state, it is no longer valid - to use the object. - - Todo: - - Consider the case that only shared nodes are created (thus no inspection of - the graph is needed) and make all shared. Then use the dead dataflow - elimination transformation to get rid of the ones we no longer need. - - Increase the applicability. - """ - - map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) - intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) - map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. - """ - return [ - dace.sdfg.utils.node_path_graph( - cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2 - ) - ] - - def can_be_applied( - self, - graph: Union[SDFGState, SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - Satisfies general requirements, see `MapFusionHelper.can_be_fused()`. - - Tests if the decomposition exists. - - Tests if there are read write dependencies. - """ - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): - return False - - # Test for read-write conflicts - if self.has_read_write_dependency( - map_entry_1=map_entry_1, - map_entry_2=map_entry_2, - state=graph, - sdfg=sdfg, - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - first_map_exit=map_exit_1, - second_map_entry=map_entry_2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def has_read_write_dependency( - self, - map_entry_1: nodes.MapEntry, - map_entry_2: nodes.MapEntry, - state: SDFGState, - sdfg: SDFG, - ) -> bool: - """Test if there is a read write dependency between the two maps to be fused. - - The function checks two different things. - - The function will make sure that there is no read write dependency between - the input and output of the fused maps. For that it will inspect the - respective subsets. - - The second part partially checks the intermediate nodes, it mostly ensures - that there are not views and that they are not used as inputs or outputs - at the same time. However, the function will not check for read write - conflicts in this set, this is done in the partition function. - - Returns: - `True` if there is a conflict between the maps that can not be handled. - If there is no conflict or if the conflict can be handled `False` - is returned. - - Args: - map_entry_1: The entry node of the first map. - map_entry_2: The entry node of the second map. - state: The state on which we operate. - sdfg: The SDFG on which we operate. - """ - map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) - map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) - - # Get the read and write sets of the different maps, note that Views - # are not resolved yet. - access_sets: List[Dict[str, nodes.AccessNode]] = [] - for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: - access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) - access_sets.append({node.data: node for node in access_set}) - # If two different access nodes of the same scoping node refers to the - # same data, then we consider this as a dependency we can not handle. - # It is only a problem for the intermediate nodes and might be possible - # to handle, but doing so is hard, so we just forbid it. - if len(access_set) != len(access_sets[-1]): - return True - read_map_1, write_map_1, read_map_2, write_map_2 = access_sets - - # It might be possible that there are views, so we have to resolve them. - # We also already get the name of the data container. - # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. - resolved_sets: List[Set[str]] = [] - for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: - resolved_sets.append( - { - self.track_view(node, state, sdfg).data - if self.is_view(node, sdfg) - else node.data - for node in unresolved_set.values() - } - ) - # If the resolved and unresolved names do not have the same length. - # Then different views point to the same location, which we forbid - if len(unresolved_set) != len(resolved_sets[-1]): - return False - real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets - - # We do not allow that the first and second map each write to the same data. - if not real_write_map_1.isdisjoint(real_write_map_2): - return True - - # If there is no overlap in what is (totally) read and written, there will be no conflict. - # This must come before the check of disjoint write. - if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): - return False - - # These are the names (unresolved) and the access nodes of the data that is used - # to transmit information between the maps. The partition function ensures that - # these nodes are directly connected to the two maps. - exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) - exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( - read_map_2.values() - ) - - # If the number are different then a data is accessed through multiple nodes. - if len(exchange_names) != len(exchange_nodes): - return True - assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) - - # For simplicity we assume that the nodes used for exchange are not views. - if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): - return True - - # This is the names of the node that are used as input of the first map and - # as output of the second map. We have to ensure that there is no data - # dependency between these nodes. - fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) - - # If a data container is used as input and output then it can not be a view (simplicity) - if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): - return True - - # A data container can be used as input and output. But we do not allow that - # it is also used as intermediate or exchange data. This is an important check. - if not fused_inout_data_names.isdisjoint(exchange_names): - return True - - # Get the replacement dict for changing the map variables from the subsets of - # the second map. - repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) - - # Now we inspect if there is a read write dependency, between data that is - # used as input and output of the fused map. There is no problem is they - # are pointwise, i.e. in each iteration the same locations are accessed. - # Essentially they all boil down to `a += 1`. - for inout_data_name in fused_inout_data_names: - all_subsets: List[subsets.Subset] = [] - # The subsets that define reading are given by the first map's entry node - all_subsets.extend( - self.find_subsets( - node=read_map_1[inout_data_name], - scope_node=map_entry_1, - state=state, - sdfg=sdfg, - repl_dict=None, - ) - ) - # While the subsets defining writing are given by the second map's exit - # node, there we also have to apply renaming. - all_subsets.extend( - self.find_subsets( - node=write_map_2[inout_data_name], - scope_node=map_exit_2, - state=state, - sdfg=sdfg, - repl_dict=repl_dict, - ) - ) - # Now we can test if these subsets are point wise - if not self.test_if_subsets_are_point_wise(all_subsets): - return True - - # No read write dependency was found. - return False - - def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: - """Point wise means that they are all the same. - - If a series of subsets are point wise it means that all Memlets, access - the same data. This is an important property because the whole map fusion - is build upon this. - If the subsets originates from different maps, then they must have been - renamed. - - Args: - subsets_to_check: The list of subsets that should be checked. - """ - assert len(subsets_to_check) > 1 - - # We will check everything against the master subset. - master_subset = subsets_to_check[0] - for ssidx in range(1, len(subsets_to_check)): - subset = subsets_to_check[ssidx] - if isinstance(subset, subsets.Indices): - subset = subsets.Range.from_indices(subset) - # Do we also need the reverse? See below why. - if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): - return False - else: - # The original code used `Range.offset` here, but that one had trouble - # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test - # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would - # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not - # what we want. Thus we will use symmetric cover. - if not master_subset.covers(subset): - return False - if not subset.covers(master_subset): - return False - - # All subsets are equal to the master subset, thus they are equal to each other. - # This means that the data accesses, described by this transformation is - # point wise - return True - - def compute_offset_subset( - self, - original_subset: subsets.Range, - intermediate_desc: data.Data, - map_params: List[str], - producer_offset: Optional[subsets.Range] = None, - ) -> subsets.Range: - """Computes the memlet to correct read and writes of the intermediate. - - Args: - original_subset: The original subset that was used to write into the - intermediate, must be renamed to the final map parameter. - intermediate_desc: The original intermediate data descriptor. - map_params: The parameter of the final map. - """ - assert not isinstance(intermediate_desc, data.View) - final_offset: subsets.Range = None - if isinstance(intermediate_desc, data.Scalar): - final_offset = subsets.Range.from_string("0") - - elif isinstance(intermediate_desc, data.Array): - basic_offsets = original_subset.min_element() - offset_list = [] - for d in range(original_subset.dims()): - d_range = subsets.Range([original_subset[d]]) - if d_range.free_symbols.intersection(map_params): - offset_list.append(d_range[0]) - else: - offset_list.append((basic_offsets[d], basic_offsets[d], 1)) - final_offset = subsets.Range(offset_list) - - else: - raise TypeError( - f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." - ) - - if producer_offset is not None: - # Here we are correcting some parts that over approximate (which partially - # does under approximate) might screw up. Consider two maps, the first - # map only writes the subset `[:, 2:6]`, thus the new intermediate will - # have shape `(1, 4)`. Now also imagine that the second map only reads - # the elements `[:, 3]`. From this we see that we can only correct the - # consumer side if we also take the producer side into consideration! - # See also the `transformations/mapfusion_test.py::test_offset_correction_*` - # tests for more. - final_offset.offset( - final_offset.offset_new( - producer_offset, - negative=True, - ), - negative=True, - ) - return final_offset - - def partition_first_outputs( - self, - state: SDFGState, - sdfg: SDFG, - first_map_exit: nodes.MapExit, - second_map_entry: nodes.MapEntry, - ) -> Union[ - Tuple[ - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - Set[graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `first_map_exit` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - * Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - * Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - * Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - If strict data flow mode is enabled the function is rather strict if an - output can be added to either intermediate set and might fail to compute - the partition, even if it would exist. - - :return: If such a decomposition exists the function will return the three sets - mentioned above in the same order. In case the decomposition does not exist, - i.e. the maps can not be fused the function returns `None`. - - :param state: The in which the two maps are located. - :param sdfg: The full SDFG in whcih we operate. - :param first_map_exit: The exit node of the first map. - :param second_map_entry: The entry node of the second map. - """ - # The three outputs set. - pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() - - # Compute the renaming that for translating the parameter of the _second_ - # map to the ones used by the first map. - param_repl: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] - first_map=first_map_exit.map, - second_map=second_map_entry.map, - ) - assert param_repl is not None - - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: Set[nodes.Node] = set() - - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(first_map_exit): - intermediate_node: nodes.Node = out_edge.dst - - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - # TODO(phimuell): Handle this case, already partially handled here. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # The intermediate can only have one incoming degree. It might be possible - # to handle multiple incoming edges, if they all come from the top map. - # However, the resulting SDFG might be invalid. - # NOTE: Allow this to happen (under certain cases) if the only producer - # is the top map. - if state.in_degree(intermediate_node) != 1: - return None - - # If the second map is not reachable from the intermediate node, then - # the output is pure and we can end here. - if not self.is_node_reachable_from( - graph=state, - begin=intermediate_node, - end=second_map_entry, - ): - pure_outputs.add(out_edge) - continue - - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None - - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, nodes.AccessNode): - return None - if self.is_view(intermediate_node, sdfg): - return None - - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - # To handle this we need to associate a consumer edge (the outgoing edges - # of the second map) with exactly one producer. - producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( - state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) - ) - if len(producer_edges) > 1: - return None - - # Now check the constraints we have on the producers. - # - The source of the producer can not be a view (we do not handle this) - # - The edge shall also not be a reduction edge. - # - Defined location to where they write. - # - No dynamic Melets. - # Furthermore, we will also extract the subsets, i.e. the location they - # modify inside the intermediate array. - # Since we do not allow for WCR, we do not check if the producer subsets intersects. - producer_subsets: List[subsets.Subset] = [] - for producer_edge in producer_edges: - if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view( - producer_edge.src, sdfg - ): - return None - if producer_edge.data.dynamic: - # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely. - return None - if producer_edge.data.wcr is not None: - return None - if producer_edge.data.dst_subset is None: - return None - producer_subsets.append(producer_edge.data.dst_subset) - - # Check if the producer do not intersect - if len(producer_subsets) == 1: - pass - elif len(producer_subsets) == 2: - if producer_subsets[0].intersects(producer_subsets[1]): - return None - else: - for i, psbs1 in enumerate(producer_subsets): - for j, psbs2 in enumerate(producer_subsets): - if i == j: - continue - if psbs1.intersects(psbs2): - return None - - # Now we determine the consumer of nodes. For this we are using the edges - # leaves the second map entry. It is not necessary to find the actual - # consumer nodes, as they might depend on symbols of nested Maps. - # For the covering test we only need their subsets, but we will perform - # some scan and filtering on them. - found_second_map = False - consumer_subsets: List[subsets.Subset] = [] - for intermediate_node_out_edge in state.out_edges(intermediate_node): - # If the second map entry is not immediately reachable from the intermediate - # node, then ensure that there is not path that goes to it. - if intermediate_node_out_edge.dst is not second_map_entry: - if self.is_node_reachable_from( - graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry - ): - return None - continue - - # Ensure that the second map is found exactly once. - # TODO(phimuell): Lift this restriction. - if found_second_map: - return None - found_second_map = True - - # The output of the top map can not define a dynamic map range in the - # second map. - if not intermediate_node_out_edge.dst_conn.startswith("IN_"): - return None - - # Now we look at all edges that leave the second map entry, i.e. the - # edges that feeds the consumer and define what is read inside the map. - # We do not check them, but collect them and inspect them. - # NOTE1: The subset still uses the old iteration variables. - # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets. - # This is different compared to the producer Memlet. The reason is - # because in a consumer the data is conditionally read, so the data - # has to exists anyway. - for inner_consumer_edge in state.out_edges_by_connector( - second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:] - ): - if inner_consumer_edge.data.src_subset is None: - return None - consumer_subsets.append(inner_consumer_edge.data.src_subset) - assert ( - found_second_map - ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." - assert len(consumer_subsets) != 0 - - # The consumer still uses the original symbols of the second map, so we must rename them. - if param_repl: - consumer_subsets = copy.deepcopy(consumer_subsets) - for consumer_subset in consumer_subsets: - symbolic.safe_replace( - mapping=param_repl, replace_callback=consumer_subset.replace - ) - - # Now we are checking if a single iteration of the first (top) map - # can satisfy all data requirements of the second (bottom) map. - # For this we look if the producer covers the consumer. A consumer must - # be covered by exactly one producer. - for consumer_subset in consumer_subsets: - nb_coverings = sum( - producer_subset.covers(consumer_subset) for producer_subset in producer_subsets - ) - if nb_coverings != 1: - return None - - # After we have ensured coverage, we have to decide if the intermediate - # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). - # Note that "removed" here means that it is reconstructed by a new - # output of the second map. - if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): - # The intermediate data is used somewhere else, either in this or another state. - # NOTE: If the intermediate is shared, then we will turn it into a - # sink node attached to the combined map exit. Technically this - # should be enough, even if the same data appears again in the - # dataflow down streams. However, some DaCe transformations, - # I am looking at you `auto_optimizer()` do not like that. Thus - # if the intermediate is used further down in the same datadflow, - # then we consider that the maps can not be fused. But we only - # do this in the strict data flow mode. - if self.strict_dataflow: - if self._is_data_accessed_downstream( - data=intermediate_node.data, - graph=state, - begin=intermediate_node, # is ignored itself. - ): - return None - shared_outputs.add(out_edge) - else: - # The intermediate can be removed, as it is not used anywhere else. - exclusive_outputs.add(out_edge) - - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) - return (pure_outputs, exclusive_outputs, shared_outputs) - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. - """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit_1, nodes.MapExit) - assert isinstance(self.map_entry_2, nodes.MapEntry) - - map_exit_1: nodes.MapExit = self.map_exit_1 - map_entry_2: nodes.MapEntry = self.map_entry_2 - map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) - map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) - - # Before we do anything we perform the renaming. - self.rename_map_parameters( - first_map=map_exit_1.map, - second_map=map_entry_2.map, - second_map_entry=map_entry_2, - state=graph, - ) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - first_map_exit=map_exit_1, - second_map_entry=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. - pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. - self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - def handle_intermediate_set( - self, - intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], - state: SDFGState, - sdfg: SDFG, - map_exit_1: nodes.MapExit, - map_entry_2: nodes.MapEntry, - map_exit_2: nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. - The function assumes that the parameter renaming was already done. - - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - """ - first_map_exit = map_exit_1 - second_map_entry = map_entry_2 - second_map_exit = map_exit_2 - map_params = first_map_exit.map.params.copy() - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) - ) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - - (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( - self.compute_reduced_intermediate( - producer_subset=pre_exit_edge.data.dst_subset, - inter_desc=inter_desc, - ) - ) - - # This is the name of the new "intermediate" node that we will create. - # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. - if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) - - # Get the subset that defined into which part of the old intermediate - # the old output edge wrote to. We need that to adjust the producer - # Memlets, since they now write into the new (smaller) intermediate. - producer_offset = self.compute_offset_subset( - original_subset=pre_exit_edge.data.dst_subset, - intermediate_desc=inter_desc, - map_params=map_params, - producer_offset=None, - ) - - # Memlets have a lot of additional informations, to ensure that we get - # all of them, we have to do it this way. The main reason for this is - # to handle the case were the "Memlet reverse direction", i.e. `data` - # refers to the other end of the connection than before. - assert pre_exit_edge.data.dst_subset is not None - new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) - new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) - - new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - new_pre_exit_memlet.data = new_inter_name - - new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We can update `{src, dst}_subset` only after we have inserted the - # edge, this is because the direction of the Memlet might change. - new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset - new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset - - # We now handle the MemletTree defined by this edge. - # The newly created edge, only handled the last collection step. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( - include_self=False - ): - producer_edge = producer_tree.edge - - # In order to preserve the intrinsic direction of Memlets we only have to change - # the `.data` attribute of the producer Memlet if it refers to the old intermediate. - # If it refers to something different we keep it. Note that this case can only - # occur if the producer is an AccessNode. - if producer_edge.data.data == inter_name: - producer_edge.data.data = new_inter_name - - # Regardless of the intrinsic direction of the Memlet, the subset we care about - # is always `dst_subset`. - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - # Since we now write into a smaller memory patch, we must - # compensate for that. We do this by substracting where the write - # originally had begun. - producer_edge.data.dst_subset.offset(producer_offset, negative=True) - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: Set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == second_map_entry: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. - assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. - for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): - # As for the producer side, we now read from a smaller array, - # So we must offset them, we use the original edge for this. - assert inner_edge.data.src_subset is not None - consumer_offset = self.compute_offset_subset( - original_subset=inner_edge.data.src_subset, - intermediate_desc=inter_desc, - map_params=map_params, - producer_offset=producer_offset, - ) - - # Now create the memlet for the new consumer. To make sure that we get all attributes - # of the Memlet we make a deep copy of it. There is a tricky part here, we have to - # access `src_subset` however, this is only correctly set once it is put inside the - # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. - # i.e. that the association of `subset` and `other_subset` does not change. For this - # reason we only modify `.data` attribute of the Memlet if its name refers to the old - # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` - # after we have inserted it to the SDFG. - new_inner_memlet = copy.deepcopy(inner_edge.data) - if inner_edge.data.data == inter_name: - new_inner_memlet.data = new_inter_name - - # Now we replace the edge from the SDFG. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now modifying the Memlet, we do it after the insertion to make - # sure that the Memlet was properly initialized. - if is_scalar: - new_inner_memlet.subset = "0" - elif new_inner_memlet.src_subset is not None: - # TODO(phimuell): Figuring out if `src_subset` is None is an error. - new_inner_memlet.src_subset.offset(consumer_offset, negative=True) - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now we have to make sure that all consumers are properly updated. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( - include_self=False - ): - consumer_edge = consumer_tree.edge - - # We only modify the data if the Memlet refers to the old intermediate data. - # We can not do this unconditionally, because it might change the intrinsic - # direction of a Memlet and then `src_subset` would at the next `try_initialize` - # be wrong. Note that this case only occurs if the destination is an AccessNode. - if consumer_edge.data.data == inter_name: - consumer_edge.data.data = new_inter_name - - # Now we have to adapt the subsets. - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.src_subset is not None: - # TODO(phimuell): Figuring out if `src_subset` is None is an error. - consumer_edge.data.src_subset.offset(consumer_offset, negative=True) - consumer_edge.data.src_subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. We now delete - # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): - assert edge.src == inter_node - state.remove_edge(edge) - second_map_entry.remove_in_connector(in_conn_name) - second_map_entry.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - # This will also remove `out_edge` from the SDFG. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) - first_map_exit.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # TODO(phimuell): Lift this restriction - assert pre_exit_edge.data.data == inter_name - - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. - final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) - - new_pre_exit_conn = second_map_exit.next_connector() - state.add_edge( - new_inter_node, - None, - second_map_exit, - "IN_" + new_pre_exit_conn, - final_pre_exit_memlet, - ) - state.add_edge( - second_map_exit, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) - second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) - - first_map_exit.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) - - def compute_reduced_intermediate( - self, - producer_subset: subsets.Range, - inter_desc: dace.data.Data, - ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: - """Compute the size of the new (reduced) intermediate. - - `MapFusion` does not only fuses map, but, depending on the situation, also - eliminates intermediate arrays between the two maps. To transmit data between - the two maps a new, but much smaller intermediate is needed. - - :return: The function returns a tuple with three values with the following meaning: - * The raw shape of the reduced intermediate. - * The cleared shape of the reduced intermediate, essentially the raw shape - with all shape 1 dimensions removed. - * Which dimensions of the raw shape have been removed to get the cleared shape. - - :param producer_subset: The subset that was used to write into the intermediate. - :param inter_desc: The data descriptor for the intermediate. - """ - assert producer_subset is not None - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) - inter_shape = inter_desc.shape - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape, strict=True) - ): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) - - return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py index 90ad67e7cb..14f5f56689 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py @@ -372,9 +372,9 @@ def _test_if_promoted_maps_can_be_fused( "only_inner_maps": self.only_inner_maps, "only_toplevel_maps": self.only_toplevel_maps, }, - map_exit_1=first_map_exit, - intermediate_access_node=access_node, - map_entry_2=second_map_entry, + first_map_exit=first_map_exit, + array=access_node, + second_map_entry=second_map_entry, ): return False diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index cd4ad77787..154c2bb46e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -21,6 +21,8 @@ transformations as gtx_transformations, ) +import dace + from . import util @@ -64,11 +66,11 @@ def _make_serial_sdfg_1( state.add_mapped_tasklet( name="second_computation", - map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + map_ranges=[("__i4", f"0:{N}"), ("__i6", f"0:{N}")], input_nodes={tmp}, - inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, + inputs={"__in0": dace.Memlet("tmp[__i4, __i6]")}, code="__out = __in0 + 3.0", - outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + outputs={"__out": dace.Memlet("b[__i4, __i6]")}, external_edges=True, ) @@ -133,11 +135,11 @@ def _make_serial_sdfg_2( ) state.add_mapped_tasklet( name="second_computation", - map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + map_ranges=[("__i3", f"0:{N}"), ("__i6", f"0:{N}")], input_nodes={tmp_2}, - inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, + inputs={"__in0": dace.Memlet("tmp_2[__i3, __i6]")}, code="__out = __in0 - 3.0", - outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + outputs={"__out": dace.Memlet("c[__i3, __i6]")}, external_edges=True, ) @@ -197,20 +199,68 @@ def _make_serial_sdfg_3( state.add_mapped_tasklet( name="indirect_access", - map_ranges=[("__i0", f"0:{N_output}")], + map_ranges=[("__i1", f"0:{N_output}")], input_nodes={tmp}, inputs={ - "__index": dace.Memlet("idx[__i0]"), + "__index": dace.Memlet("idx[__i1]"), "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), }, code="__out = __array[__index]", - outputs={"__out": dace.Memlet("c[__i0]")}, + outputs={"__out": dace.Memlet("c[__i1]")}, external_edges=True, ) return sdfg +def _make_parallel_sdfg_1( + single_input_node: bool, +) -> tuple[dace.SDFG, dace.SDFGState]: + """Make a parallel SDFG. + + The maps access both the same Data but uses different AccessNodes for that. + If `single_input_node` is `True` then there will only one AccessNode for `a` + be created, otherwise each map has its own. + """ + sdfg = dace.SDFG(util.unique_name("parallel_sdfg_1")) + state = sdfg.add_state(is_start_block=True) + + for name in "abc": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + a1, b, c = (state.add_access(name) for name in "abc") + a2 = a1 if single_input_node else state.add_access("a") + + state.add_mapped_tasklet( + "map1", + map_ranges={"__i0": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("b[__i0]")}, + input_nodes={a1}, + output_nodes={b}, + external_edges=True, + ) + state.add_mapped_tasklet( + "map2", + map_ranges={"__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i1]")}, + code="__out = __in - 10.", + outputs={"__out": dace.Memlet("c[__i1]")}, + input_nodes={a2}, + output_nodes={c}, + external_edges=True, + ) + sdfg.validate() + + return sdfg, state + + def test_exclusive_itermediate(): """Tests if the exclusive intermediate branch works.""" N = 10 @@ -507,3 +557,37 @@ def test_indirect_access_2(): validate_all=True, ) assert count == 0 + + +def test_parallel_1(): + sdfg, state = _make_parallel_sdfg_1(single_input_node=False) + assert util.count_nodes(state, dace_nodes.AccessNode) == 4 + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + + # Because we request a common ancestor it will not apply. + # NOTE: We might have to change that if the implementation changes. + nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=True)] + ) + assert nb_applies == 0 + + # If we do not restrict common ancestor then it will work. + nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=False)] + ) + + assert nb_applies == 1 + assert util.count_nodes(state, dace_nodes.AccessNode) == 4 + assert util.count_nodes(state, dace_nodes.MapEntry) == 1 + + +def test_parallel_2(): + sdfg, state = _make_parallel_sdfg_1(single_input_node=True) + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + + nb_applies = sdfg.apply_transformations_repeated([gtx_transformations.MapFusionParallel()]) + + assert nb_applies == 1 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert util.count_nodes(state, dace_nodes.MapEntry) == 1 From 29f95816efe7b95129a760b8120e48f7eec86316 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 11:52:02 +0100 Subject: [PATCH 144/178] bug[next]: Fix domain pickle after `slice_at` call (#1865) --- src/gt4py/next/common.py | 6 +++++ .../embedded_tests/test_domain_pickle.py | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 9b2870e1c0..e5b393f1ae 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -574,6 +574,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + # remove cached property + state.pop("slice_at", None) + return state + FiniteDomain: TypeAlias = Domain[FiniteUnitRange] diff --git a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py new file mode 100644 index 0000000000..b69950928d --- /dev/null +++ b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pickle + +from gt4py.next import common + +I = common.Dimension("I") +J = common.Dimension("J") + + +def test_domain_pickle_after_slice(): + domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + # use slice_at to populate cached property + domain.slice_at[2:5, 5:7] + + pickle.dumps(domain) From c76f6d3cb9bda87c3055883e19abe0ffb2906736 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 11:52:13 +0100 Subject: [PATCH 145/178] bug[next]: Fix applied lift extraction in CSE (#1864) --- src/gt4py/next/iterator/transforms/cse.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccaaf563f5..cc1ffc3c50 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -82,11 +82,13 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. - # do also not collect reduce nodes if they are left in the it at this point, this may lead to + # do also not collect reduce nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]: + # do also not collect lifts or applied lifts as they become invisible to the lift inliner + # otherwise + if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node): return False return True elif isinstance(node, itir.Lambda): From 8595cfde23b3dd92c679eccb537c07d0a52d2370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 14 Feb 2025 13:20:28 +0100 Subject: [PATCH 146/178] fix[next][dace]: Fixed a bug in the map fusion's default parameter (#1866) `MapFusion` was behaving as the one from DaCe, i.e. it only did serial fusion by default, now it allows both by default. There are still `MapFusionSerial` and `MapFusionParallel` with the respective settings. This commit also changes auto optimizer to explicitly request them. --- .../dace/transformations/auto_optimize.py | 2 ++ .../runners/dace/transformations/map_fusion.py | 15 ++++++++++++--- .../dace/transformations/map_fusion_dace.py | 16 +++++++++++++--- .../transformation_tests/test_map_fusion.py | 15 +++++++++++++++ 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index d6e9fc259d..684ec9e764 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -343,6 +343,8 @@ def gt_auto_fuse_top_level_maps( fusion_transformation = gtx_transformations.MapFusion( only_toplevel_maps=True, + allow_parallel_map_fusion=True, + allow_serial_map_fusion=True, only_if_common_ancestor=False, ) fusion_transformation._single_use_data = single_use_data diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py index 00828520c8..0f1dabf0d2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py @@ -49,8 +49,10 @@ class MapFusion(dace_map_fusion.MapFusion): It is a wrapper that adds some functionality to the transformation that is not present in the DaCe version of this transformation. - There are two important differences when compared with DaCe's MapFusion: + There are three important differences when compared with DaCe's MapFusion: - In DaCe strict data flow is enabled by default, in GT4Py it is disabled by default. + - In DaCe `MapFusion` only performs the fusion of serial maps by default. In GT4Py + `MapFusion` will also perform parallel map fusion by default. - GT4Py accepts an additional argument `apply_fusion_callback`. This is a function that is called by the transformation, at the _beginning_ of `self.can_be_applied()`, i.e. before the transformation does any check if @@ -65,7 +67,7 @@ class MapFusion(dace_map_fusion.MapFusion): strict_dataflow: Strict dataflow mode should be used, it is disabled by default. assume_always_shared: Assume that all intermediates are shared. allow_serial_map_fusion: Allow serial map fusion, by default `True`. - allow_parallel_fusion: Allow to merge parallel maps, by default `False`. + allow_parallel_fusion: Allow to merge parallel maps, by default `True`. only_if_common_ancestor: In parallel map fusion mode, only fuse if both maps have a common direct ancestor. apply_fusion_callback: The callback function that is used. @@ -81,11 +83,18 @@ class MapFusion(dace_map_fusion.MapFusion): def __init__( self, strict_dataflow: bool = False, + allow_serial_map_fusion: bool = True, + allow_parallel_map_fusion: bool = True, apply_fusion_callback: Optional[FusionTestCallback] = None, **kwargs: Any, ) -> None: self._apply_fusion_callback = None - super().__init__(strict_dataflow=strict_dataflow, **kwargs) + super().__init__( + strict_dataflow=strict_dataflow, + allow_serial_map_fusion=allow_serial_map_fusion, + allow_parallel_map_fusion=allow_parallel_map_fusion, + **kwargs, + ) if apply_fusion_callback is not None: self._apply_fusion_callback = apply_fusion_callback diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py index c301ce0ac4..67fbe4182d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py @@ -293,12 +293,19 @@ def can_parallel_map_fusion_be_applied( sdfg: dace.SDFG, ) -> bool: """Check if the matched Maps can be fused in parallel.""" - assert self.expr_index == 1 # NOTE: The after this point it is not legal to access the matched nodes first_map_entry: nodes.MapEntry = self.first_parallel_map_entry second_map_entry: nodes.MapEntry = self.second_parallel_map_entry + assert self.expr_index == 1 + assert isinstance(first_map_entry, nodes.MapEntry) + assert isinstance(second_map_entry, nodes.MapEntry) + + # We will now check if the two maps are parallel. + if not self.is_parallel(graph=graph, node1=first_map_entry, node2=second_map_entry): + return False + # Check the structural properties of the Maps. The function will return # the `dict` that describes how the parameters must be renamed (for caching) # or `None` if the maps can not be structurally fused. @@ -333,13 +340,16 @@ def can_serial_map_fusion_be_applied( * Tests if there are read write dependencies. * Tests if the decomposition exists. """ - assert self.expr_index == 0 - # NOTE: The after this point it is not legal to access the matched nodes first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit) first_map_exit: nodes.MapExit = self.first_map_exit second_map_entry: nodes.MapEntry = self.second_map_entry + assert self.expr_index == 0 + assert isinstance(first_map_exit, nodes.MapExit) + assert isinstance(second_map_entry, nodes.MapEntry) + assert isinstance(self.array, nodes.AccessNode) + # Check the structural properties of the Maps. The function will return # the `dict` that describes how the parameters must be renamed (for caching) # or `None` if the maps can not be structurally fused. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index 154c2bb46e..ecf5a4762b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -591,3 +591,18 @@ def test_parallel_2(): assert nb_applies == 1 assert util.count_nodes(state, dace_nodes.AccessNode) == 3 assert util.count_nodes(state, dace_nodes.MapEntry) == 1 + + +def test_parallel_3(): + """Test that the parallel map fusion does not apply for serial maps.""" + sdfg = _make_serial_sdfg_1(20) + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + # Because the maps are fully serial, parallel map fusion should never apply. + nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=False)] + ) + assert nb_applies == 0 + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 From 3f85165916ca1359a053fdf900dc4d254ecec1c8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 17 Feb 2025 10:05:11 +0100 Subject: [PATCH 147/178] refactor[cartesian]: unexpanded sdfg cleanups (#1860) ## Description Refactors from (recent) debugging sessions around transient arrays in the "unexpanded SDFG" (the one with the library nodes): - Remove unused `**kwargs` from `OirSDFGBuilder` - Forward debugging information about transient arrays to DaCe - Use a (constant) variable for connector prefixes of data going into/out of the library nodes - Configure the lifetime of transient arrays directly in `OirSDFGBuilder` This is a follow-up from PR https://github.com/GridTools/gt4py/pull/1843. In this PR, we separate the DaCe backend cleanup from the refactor around (re-)using `DeviceType` instead of `"cpu" | "gpu"` string literals. ## Requirements - [x] All fixes and/or new features come with corresponding tests. Should be covered by existing tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <> Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Florian Deconinck --- src/gt4py/cartesian/backend/dace_backend.py | 4 --- src/gt4py/cartesian/gtc/dace/constants.py | 15 +++++++++ .../cartesian/gtc/dace/expansion/expansion.py | 17 ++++++---- .../gtc/dace/expansion/sdfg_builder.py | 10 +++--- .../cartesian/gtc/dace/expansion/utils.py | 14 -------- src/gt4py/cartesian/gtc/dace/nodes.py | 6 ++-- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 32 +++++++++++-------- src/gt4py/cartesian/gtc/dace/utils.py | 9 ++++++ 8 files changed, 61 insertions(+), 46 deletions(-) create mode 100644 src/gt4py/cartesian/gtc/dace/constants.py diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5b822a1ab5..8ca18705c9 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -151,10 +151,6 @@ def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, la sdfg.add_state(gtir_pipeline.gtir.name) return sdfg - for array in sdfg.arrays.values(): - if array.transient: - array.lifetime = dace.AllocationLifetime.Persistent - sdfg.simplify(validate=False) _set_expansion_orders(sdfg) diff --git a/src/gt4py/cartesian/gtc/dace/constants.py b/src/gt4py/cartesian/gtc/dace/constants.py new file mode 100644 index 0000000000..5565f1c186 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/constants.py @@ -0,0 +1,15 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from typing import Final + + +# StencilComputation in/out connector prefixes +CONNECTOR_PREFIX_IN: Final = "__in_" +CONNECTOR_PREFIX_OUT: Final = "__out_" diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 27f55d451d..20d7743661 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -18,6 +18,7 @@ import sympy from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder @@ -77,11 +78,11 @@ def _fix_context( """ # change connector names for in_edge in parent_state.in_edges(node): - assert in_edge.dst_conn.startswith("__in_") - in_edge.dst_conn = in_edge.dst_conn[len("__in_") :] + assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN) + in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN) for out_edge in parent_state.out_edges(node): - assert out_edge.src_conn.startswith("__out_") - out_edge.src_conn = out_edge.src_conn[len("__out_") :] + assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT) + out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT) # union input and output subsets subsets = {} @@ -125,9 +126,13 @@ def _get_parent_arrays( ) -> Dict[str, dace.data.Data]: parent_arrays: Dict[str, dace.data.Data] = {} for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): - parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[ + edge.data.data + ] for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None): - parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[ + edge.data.data + ] return parent_arrays @staticmethod diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 6728ccaa7d..3aeda7a484 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -20,9 +20,8 @@ from gt4py import eve from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen -from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import make_dace_subset +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait): @@ -268,13 +267,13 @@ def visit_ComputationState( for memlet in computation.read_memlets: if memlet.field not in read_acc_and_conn: read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) for memlet in computation.write_memlets: if memlet.field not in write_acc_and_conn: write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) node_ctx = StencilComputationSDFGBuilder.NodeContext( @@ -298,7 +297,7 @@ def visit_FieldDecl( dtype=data_type_to_dace_typeclass(node.dtype), storage=node.storage.to_dace_storage(), transient=node.name not in non_transients, - debuginfo=dace.DebugInfo(0), + debuginfo=get_dace_debuginfo(node), ) def visit_SymbolDecl( @@ -343,7 +342,6 @@ def visit_NestedSDFG( inputs=node.input_connectors, outputs=node.output_connectors, symbol_mapping=symbol_mapping, - debuginfo=dace.DebugInfo(0), ) self.visit( node.read_memlets, diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 919ec02996..7a29ec99a6 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -10,11 +10,6 @@ from typing import TYPE_CHECKING, List -import dace -import dace.data -import dace.library -import dace.subsets - from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir @@ -25,15 +20,6 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation -def get_dace_debuginfo(node: common.LocNode): - if node.loc is not None: - return dace.dtypes.DebugInfo( - node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename - ) - else: - return dace.dtypes.DebugInfo(0) - - class HorizontalIntervalRemover(eve.NodeTranslator): def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis): mask_attrs = dict(i=node.i, j=node.j) diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 34401e18b9..13fb6ecc6e 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -23,12 +23,12 @@ from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion +from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter +from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection -from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo -from .expansion_specification import ExpansionItem, make_expansion_order - def _set_expansion_order( node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]] diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 14448bb08e..9dd66bac82 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -18,9 +18,14 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset +from gt4py.cartesian.gtc.dace.utils import ( + compute_dcir_access_infos, + get_dace_debuginfo, + make_dace_subset, +) from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.passes.oir_optimizations.utils import ( AccessCollector, @@ -93,9 +98,7 @@ def _make_dace_subset(self, local_access_info, field): global_access_info, local_access_info, self.decls[field].data_dims ) - def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs - ): + def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext): declarations = { acc.name: ctx.decls[acc.name] for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess) @@ -117,22 +120,24 @@ def visit_VerticalLoop( access_collection = AccessCollector.apply(node) for field in access_collection.read_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_in_connector("__in_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = f"{CONNECTOR_PREFIX_IN}{field}" + library_node.add_in_connector(connector_name) subset = ctx.make_input_dace_subset(node, field) state.add_edge( - access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) + access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset) ) for field in access_collection.write_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_out_connector("__out_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = f"{CONNECTOR_PREFIX_OUT}{field}" + library_node.add_out_connector(connector_name) subset = ctx.make_output_dace_subset(node, field) state.add_edge( - library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) + library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset) ) - def visit_Stencil(self, node: oir.Stencil, **kwargs): + def visit_Stencil(self, node: oir.Stencil): ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: if isinstance(param, oir.FieldDecl): @@ -148,7 +153,7 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(param.dtype), transient=False, - debuginfo=dace.DebugInfo(0), + debuginfo=get_dace_debuginfo(param), ) else: ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype)) @@ -166,7 +171,8 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(decl.dtype), transient=True, - debuginfo=dace.DebugInfo(0), + lifetime=dace.AllocationLifetime.Persistent, + debuginfo=get_dace_debuginfo(decl), ) self.generic_visit(node, ctx=ctx) ctx.sdfg.validate() diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index bd65861a49..4e8a0f0c7b 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -23,6 +23,15 @@ from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents +def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo: + if node.loc is None: + return dace.dtypes.DebugInfo(0) + + return dace.dtypes.DebugInfo( + node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename + ) + + def array_dimensions(array: dace.data.Array): dims = [ any( From f7f2a0a9a52d54ab3881777a69dc3fe46c8dde62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 17 Feb 2025 16:20:01 +0100 Subject: [PATCH 148/178] fix[dace][next]: Fixed name of pattern nodes (#1871) In PR#1857 we updated the `MapFusion` because of that the names of the pattern nodes have been changed. Apparently, in that PR not all occurrences of the old names have been replaced. --- .../runners/dace/transformations/gpu_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 6359cc1127..10ed652ec2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -718,9 +718,9 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non if not self.do_not_fuse: gtx_transformations.MapFusionSerial.apply_to( sdfg=sdfg, - map_exit_1=trivial_map_exit, - intermediate_access_node=access_node, - map_entry_2=second_map_entry, + first_map_exit=trivial_map_exit, + array=access_node, + second_map_entry=second_map_entry, verify=True, ) From cae8753fd3fa1c740e9c18c9fe5ba95580cb0d13 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 17 Feb 2025 16:33:08 +0100 Subject: [PATCH 149/178] ci: Read env variable from os environment, not from nox (#1869) Contains two types of changes: - Cleanup of `DOCKER_BUILD_ARGS`: this variable is passed to the docker build, but not used, which occasionally triggers an unnecessary rebuild. - As consequence of `DOCKER_BUILD_ARGS` cleanup the base image has been rebuilt. In the new image we pull a new nox version, which removes the os environment variables from the nox session environment (refer to nox PR https://github.com/wntrblm/nox/pull/874). Therefore, such variables need to be read from the os environment. --- ci/cscs-ci.yml | 2 +- noxfile.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 05955913ba..712c2450d6 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -34,7 +34,7 @@ stages: DOCKERFILE: ci/base.Dockerfile # change to 'always' if you want to rebuild, even if target tag exists already (if-not-exists is the default, i.e. we could also skip the variable) CSCS_REBUILD_POLICY: if-not-exists - DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' + DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION"]' .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: diff --git a/noxfile.py b/noxfile.py index 81b1354157..3aad565837 100644 --- a/noxfile.py +++ b/noxfile.py @@ -87,7 +87,7 @@ def test_cartesian( groups=["test"], ) - num_processes = session.env.get("NUM_PROCESSES", "auto") + num_processes = os.environ.get("NUM_PROCESSES", "auto") markers = " and ".join(codegen_settings["markers"] + device_settings["markers"]) session.run( @@ -111,7 +111,7 @@ def test_examples(session: nox.Session) -> None: session.run(*"jupytext docs/user/next/QuickstartGuide.md --to .ipynb".split()) session.run(*"jupytext docs/user/next/advanced/*.md --to .ipynb".split()) - num_processes = session.env.get("NUM_PROCESSES", "auto") + num_processes = os.environ.get("NUM_PROCESSES", "auto") for notebook, extra_args in [ ("docs/user/next/workshop/slides", None), ("docs/user/next/workshop/exercises", ["-k", "solutions"]), @@ -131,7 +131,7 @@ def test_eve(session: nox.Session) -> None: _install_session_venv(session, groups=["test"]) - num_processes = session.env.get("NUM_PROCESSES", "auto") + num_processes = os.environ.get("NUM_PROCESSES", "auto") session.run( *f"pytest --cache-clear -sv -n {num_processes}".split(), @@ -180,7 +180,7 @@ def test_next( groups=groups, ) - num_processes = session.env.get("NUM_PROCESSES", "auto") + num_processes = os.environ.get("NUM_PROCESSES", "auto") markers = " and ".join(codegen_settings["markers"] + device_settings["markers"] + mesh_markers) session.run( @@ -211,7 +211,7 @@ def test_storage( session, extras=["performance", "testing", *device_settings["extras"]], groups=["test"] ) - num_processes = session.env.get("NUM_PROCESSES", "auto") + num_processes = os.environ.get("NUM_PROCESSES", "auto") markers = " and ".join(device_settings["markers"]) session.run( From 0d65ae95de5b6fd7b60d07d0a4341bdfa65e7a65 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 18 Feb 2025 09:10:45 +0100 Subject: [PATCH 150/178] fix[next][dace]: use logical and/or/xor operators, not bitwise (#1872) The current mapping from GTIR logical operators to python code was inconsistent. The mapping was using bitwise operators instead of logical ones. This still resulted in functionally correct code because the boolean type has an integer representation in python language. This PR introduces the correct mapping, and leaves the dace toolchain and target compiler the possibility to generate optimized code. --- .../runners/dace/gtir_python_codegen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 763c292836..199783d893 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -66,11 +66,11 @@ "less_equal": "({} <= {})", "greater": "({} > {})", "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", + "and_": "({} and {})", + "or_": "({} or {})", + "xor_": "({} != {})", "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy + "not_": "(not {})", } From 4456f433019a7db8aa185398ca5a90181a50f1a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:35:18 +0100 Subject: [PATCH 151/178] fix[next]: Git ignore gt4py cache (#1875) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b1c8ed26e9..ebbbfaebeb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ _local /src/__init__.py /tests/__init__.py .gt_cache/ +.gt4py_cache/ .gt_cache_pytest*/ # DaCe From 1a46fb0f7bce91ddd793c771d2f5d31be26528af Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 20 Feb 2025 13:13:19 +0100 Subject: [PATCH 152/178] fix[next][dace]: remove temporary arrays with runtime shape on the output of a mapped nested SDFG (#1877) This PR provides a better fix than the one delivered earlier in https://github.com/GridTools/gt4py/pull/1828. It adds a check to detect whether the temporary output data has compile-time or runtime size. In case of runtime size, the transient array on the output connector of a mapped nested SDFG is removed. This is needed in order to avoid dynamic memory allocation inside the cuda kernel that represents a parallel map scope. --- .../runners/dace/gtir_dataflow.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index e6f33208e3..43e7c6354d 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -232,15 +232,36 @@ def connect( dest: dace.nodes.AccessNode, subset: dace_subsets.Range, ) -> None: - # retrieve the node which writes the result - last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - # Note that it could also be applied when `last_node` is a NestedSDFG, - # but an exception would be when the inner write to global data is a - # WCR memlet, because that prevents fusion of the outer map. This case - # happens for the reduce with skip values, which uses a map with WCR. - last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn + write_edge = self.state.in_edges(self.result.dc_node)[0] + write_size = write_edge.data.dst_subset.num_elements() + # check the kind of node which writes the result + if isinstance(write_edge.src, dace.nodes.Tasklet): + # The temporary data written by a tasklet can be safely deleted + assert write_size.is_constant() + remove_last_node = True + elif isinstance(write_edge.src, dace.nodes.NestedSDFG): + if write_size.is_constant(): + # Temporary data with compile-time size is allocated on the stack + # and therefore is safe to keep. We decide to keep it as a workaround + # for a dace issue with memlet propagation in combination with + # nested SDFGs containing conditional blocks. The output memlet + # of such blocks will be marked as dynamic because dace is not able + # to detect the exact size of a conditional branch dataflow, even + # in case of if-else expressions with exact same output data. + remove_last_node = False + else: + # In case the output data has runtime size it is necessary to remove + # it in order to avoid dynamic memory allocation inside a parallel + # map scope. Otherwise, the memory allocation will for sure lead + # to performance degradation, and eventually illegal memory issues + # when the gpu runs out of local memory. + remove_last_node = True + else: + remove_last_node = False + + if remove_last_node: + last_node = write_edge.src + last_node_connector = write_edge.src_conn self.state.remove_node(self.result.dc_node) else: last_node = self.result.dc_node From 1176b2d4bf733f1f43b36c12d837904a9e2b52ad Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 20 Feb 2025 19:40:00 +0100 Subject: [PATCH 153/178] fix[cartesian]: DataType.isinteger() for 16-bit integers (#1878) ## Description 16-bit integers were missing in the set of data types that return true for `DataType.isinteger()`. Added missing test coverage. ## Requirements - [x] All fixes and/or new features come with corresponding tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/gtc/common.py | 2 +- .../unit_tests/test_gtc/test_common.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index ef38a9a658..60236a3e97 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -118,7 +118,7 @@ def isbool(self): return self == self.BOOL def isinteger(self): - return self in (self.INT8, self.INT32, self.INT64) + return self in (self.INT8, self.INT16, self.INT32, self.INT64) def isfloat(self): return self in (self.FLOAT32, self.FLOAT64) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 68006c113b..4e799d2090 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -41,6 +41,24 @@ # - For testing non-leave nodes, introduce builders with defaults (for leave nodes as well) +def test_data_type_methods(): + for type in DataType: + if type == DataType.BOOL: + assert type.isbool() + else: + assert not type.isbool() + + if type in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64): + assert type.isinteger() + else: + assert not type.isinteger() + + if type in (DataType.FLOAT32, DataType.FLOAT64): + assert type.isfloat() + else: + assert not type.isfloat() + + class DummyExpr(Expr): """Fake expression for cases where a concrete expression is not needed.""" From 198469177aa7f7cd493589fa155ec65bd74fd5dc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 21 Feb 2025 15:02:20 +0100 Subject: [PATCH 154/178] fix[cartesian, dace]: warn about missing support for casting in variable k offsets (#1882) ## Description We figured that DaCe backends are currently missing support for casting in variable k offsets. This PR - adds a codegen test with a cast in a variable k offset - adds a node validator for the DaCe backends complaining about missing for support. - adds an `xfail` test for the node validator This should be fixed down the road. Here's the issue https://github.com/GridTools/gt4py/issues/1881 to keep track. The PR also has two smaller and unrelated commits - 741c448f5258fccbca942a6cc9548c7554e454c9 increases test coverage with another codgen test that has a couple of read after write access patterns which were breaking the "new bridge" (see https://github.com/GEOS-ESM/NDSL/issues/53). - e98ddc54f8571d8d24d2169a421955c4b4e795e1 just forwards all keyword arguments when visiting offsets. I don't think this was a problem until now, but it's best practice to forward everything. ## Requirements - [x] All fixes and/or new features come with corresponding tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Florian Deconinck --- src/gt4py/cartesian/gtc/dace/daceir.py | 8 +- .../gtc/dace/expansion/tasklet_codegen.py | 4 +- tests/cartesian_tests/definitions.py | 6 ++ .../test_code_generation.py | 94 ++++++++++++++++++- 4 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 492a9598c5..43a33fdd6d 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -734,7 +734,13 @@ class ScalarAccess(common.ScalarAccess, Expr): class VariableKOffset(common.VariableKOffset[Expr]): - pass + @datamodels.validator("k") + def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None: + for part in expression.walk_values(): + if isinstance(part, Cast): + raise ValueError( + "DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881." + ) class IndexAccess(common.FieldAccess, Expr): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 8033c64710..2948b9d76d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -44,7 +44,9 @@ def _visit_offset( else: int_sizes.append(None) sym_offsets = [ - dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs)) + dace.symbolic.pystr_to_symbolic( + self.visit(off, access_info=access_info, decl=decl, **kwargs) + ) for off in (node.to_dict()["i"], node.to_dict()["j"], node.k) ] for axis in access_info.variable_offset_axes: diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 7499ad4a95..4d52b9b773 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -51,6 +51,12 @@ def _get_backends_with_storage_info(storage_info_kind: str): _PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] +DACE_BACKENDS = [ + _backend_name_as_param(name) + for name in filter(lambda name: name.startswith("dace:"), _ALL_BACKEND_NAMES) +] +NON_DACE_BACKENDS = [backend for backend in ALL_BACKENDS if backend not in DACE_BACKENDS] + @pytest.fixture() def id_version(): diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 8ace0de740..8e5f3466d0 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -27,7 +27,13 @@ ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library +from cartesian_tests.definitions import ( + ALL_BACKENDS, + CPU_BACKENDS, + DACE_BACKENDS, + NON_DACE_BACKENDS, + get_array_library, +) from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, @@ -762,3 +768,89 @@ def test( out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64) test(in_arr, out_arr) assert (out_arr[:, :, :] == 388.0).all() + + +@pytest.mark.parametrize("backend", NON_DACE_BACKENDS) +def test_cast_in_index(backend): + @gtscript.stencil(backend) + def cast_in_index( + in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] + ): + """Simple copy stencil with forced cast in index calculation.""" + with computation(PARALLEL), interval(...): + out_field = in_field[0, 0, i32 - i64] + + +@pytest.mark.parametrize("backend", DACE_BACKENDS) +@pytest.mark.xfail(raises=ValueError) +def test_dace_no_cast_in_index(backend): + @gtscript.stencil(backend) + def cast_in_index( + in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] + ): + """Simple copy stencil with forced cast in index calculation.""" + with computation(PARALLEL), interval(...): + out_field = in_field[0, 0, i32 - i64] + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_read_after_write_stencil(backend): + """Stencil with multiple read after write access patterns.""" + + @gtscript.stencil(backend=backend) + def lagrangian_contributions( + q: Field[np.float64], + pe1: Field[np.float64], + pe2: Field[np.float64], + q4_1: Field[np.float64], + q4_2: Field[np.float64], + q4_3: Field[np.float64], + q4_4: Field[np.float64], + dp1: Field[np.float64], + lev: gtscript.Field[gtscript.IJ, np.int64], + ): + """ + Args: + q (out): + pe1 (in): + pe2 (in): + q4_1 (in): + q4_2 (in): + q4_3 (in): + q4_4 (in): + dp1 (in): + lev (inout): + """ + with computation(FORWARD), interval(...): + pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev] + if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]: + pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev] + q = ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl) + ) + else: + qsum = (pe1[0, 0, lev + 1] - pe2) * ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl)) + ) + lev = lev + 1 + while pe1[0, 0, lev + 1] < pe2[0, 0, 1]: + qsum += dp1[0, 0, lev] * q4_1[0, 0, lev] + lev = lev + 1 + dp = pe2[0, 0, 1] - pe1[0, 0, lev] + esl = dp / dp1[0, 0, lev] + qsum += dp * ( + q4_2[0, 0, lev] + + 0.5 + * esl + * ( + q4_3[0, 0, lev] + - q4_2[0, 0, lev] + + q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl) + ) + ) + q = qsum / (pe2[0, 0, 1] - pe2) + lev = lev - 1 From 3249aa5eb20980e3a6f4d0810577ecc30e177776 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 25 Feb 2025 17:09:55 +0100 Subject: [PATCH 155/178] test[next]: initialize output array in dace-gtir subdomain tests (#1886) --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 8ebb240339..030aa9b131 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1846,7 +1846,7 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - z_fields = (np.empty_like(a), np.empty_like(a)) + z_fields = (np.zeros_like(a), np.zeros_like(a)) a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) @@ -2037,7 +2037,7 @@ def test_gtir_index(): ], ) - v = np.empty(N, dtype=np.int32) + v = np.zeros(N, dtype=np.int32) # we need to run domain inference in order to add the domain annex information to the index node. testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) From 32c569b199f8e6946f8ffd31b3192c4b8436a46a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 26 Feb 2025 16:39:59 +0100 Subject: [PATCH 156/178] fix[dace][next]: Updated `gt_auto_optimize()` (#1889) The problem was that in the auto optimizer the "set stride correctly" function was called after the "go to GPU transformation". The GPU transformation will call `CopyToMap` transformation, this is needed such that we can set the GPU block size and their order. However, the decision if we call `CopyToMap` also depends on the strides, so we need to set them first. --------- Co-authored-by: edopao --- .../dace/transformations/auto_optimize.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 684ec9e764..739fe39584 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -66,10 +66,11 @@ def gt_auto_optimize( one with stride one. 5. If requested the function will now apply loop blocking, on the dimension indicated by `leading_dim`. - 6. If requested the SDFG will be transformed to GPU. For this the + 6. The strides of temporaries are set to match the compute order. + 7. If requested the SDFG will be transformed to GPU. For this the `gt_gpu_transformation()` function is used, that might apply several other optimizations. - 7. Afterwards some general transformations to the SDFG are applied. + 8. Afterwards some general transformations to the SDFG are applied. This includes: - Use fast implementation for library nodes. - Move small transients to stack. @@ -235,7 +236,13 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Phase 6: Going to GPU + # Phase 6: Setting the strides of transients + # It is important that we set the strides before the GPU transformation. + # Because this transformation will also apply `CopyToMap` for the Memlets + # that the DaCe runtime can not handle. + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + + # Phase 7: Going to GPU if gpu: # TODO(phimuell): The GPU function might modify the map iteration order. # This is because how it is implemented (promotion and @@ -251,7 +258,7 @@ def gt_auto_optimize( try_removing_trivial_maps=True, ) - # Phase 7: General Optimizations + # Phase 8: General Optimizations # The following operations apply regardless if we have a GPU or CPU. # The DaCe auto optimizer also uses them. Note that the reuse transient # is not done by DaCe. @@ -267,9 +274,6 @@ def gt_auto_optimize( # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) - # Now we modify the strides. - gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) - if make_persistent: gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) From 194d27a7769dee705bf6a88084bed8e67f1dfee7 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Wed, 26 Feb 2025 17:07:18 +0100 Subject: [PATCH 157/178] Remove the occurences of the old license headers (#1887) --- src/gt4py/next/ffront/signature.py | 14 -------------- src/gt4py/next/otf/arguments.py | 14 -------------- .../iterator_tests/test_if_stmt.py | 19 +++---------------- 3 files changed, 3 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/ffront/signature.py b/src/gt4py/next/ffront/signature.py index 9752ceaf32..4a58d56f57 100644 --- a/src/gt4py/next/ffront/signature.py +++ b/src/gt4py/next/ffront/signature.py @@ -6,20 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - # TODO(ricoh): This overlaps with `canonicalize_arguments`, solutions: # - merge the two # - extract the signature gathering functionality from canonicalize_arguments diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index c4235eaa9a..a9b52a49d0 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -6,20 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - from __future__ import annotations import dataclasses diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index c38a29bc61..3c2ac6e7d7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -6,30 +6,17 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later import numpy as np import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, as_fieldop, named_range -from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn +from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range +from gt4py.next.iterator.runtime import fendef, fundef, if_stmt, offset, set_at from next_tests.unit_tests.conftest import program_processor_no_transforms, run_processor + i = offset("i") From 3ba7e627fabcfc6c7b7240ea5258da676533b0c4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 26 Feb 2025 17:07:57 +0100 Subject: [PATCH 158/178] ci: Fail CI if xfails pass unexpectedly (#1888) Be strict about tests marked with xfail. If they pass unexpectedly, fail the testsuite instead of logging an XPASS. --- pyproject.toml | 2 ++ tests/next_tests/definitions.py | 5 +++++ .../feature_tests/ffront_tests/test_execution.py | 4 +--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1efce6bd29..e182b23878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -319,6 +319,7 @@ markers = [ 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', + 'uses_tuple_args_with_different_but_promotable_dims: test that requires backend support for tuple args with different but promotable dims', 'uses_tuple_iterator: tests that require backend support to deref tuple iterators', 'uses_tuple_returns: tests that require backend support for tuple results', 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', @@ -330,6 +331,7 @@ markers = [ ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' +xfail_strict = true # -- ruff -- [tool.ruff] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b412c0c273..1f81076abf 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -116,6 +116,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" +USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS = ( + "uses_tuple_args_with_different_but_promotable_dims" +) USES_TUPLE_ITERATOR = "uses_tuple_iterator" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" @@ -139,6 +142,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ @@ -168,6 +172,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d878d8d3ff..a042c60709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -262,9 +262,7 @@ def testee(a: tuple[int32, tuple[int32, cases.IField, int32]]) -> cases.IField: @pytest.mark.uses_tuple_args -@pytest.mark.xfail( - reason="Not implemented in frontend (implicit size arg handling needs to be adopted) and GTIR embedded backend." -) +@pytest.mark.uses_tuple_args_with_different_but_promotable_dims def test_tuple_arg_with_different_but_promotable_dims(cartesian_case): @gtx.field_operator def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: From 847d8abad9df8b9ad6df44693ee15eca35e9ea68 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 27 Feb 2025 09:11:45 +0100 Subject: [PATCH 159/178] test[cartesian]: conditional xfail instead of two tests (#1885) Follow-up from PR https://github.com/GridTools/gt4py/pull/1882 to merge two test cases which only differed in the expected outcome. With this PR, we'll have one single test case with a conditional `xfail` as @havogt proposed in the GridTools slack channel. Related issue: https://github.com/GridTools/gt4py/issues/1881 --- tests/cartesian_tests/definitions.py | 12 +++------ .../test_code_generation.py | 27 +++++++++---------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 4d52b9b773..38cb6caca8 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -14,7 +14,6 @@ cp = None import datetime - import numpy as np import pytest @@ -22,7 +21,7 @@ from gt4py.cartesian import utils as gt_utils -def _backend_name_as_param(name): +def _backend_name_as_param(name: str): marks = [] if gt4pyc.backend.from_name(name).storage_info["device"] == "gpu": marks.append(pytest.mark.requires_gpu) @@ -48,14 +47,9 @@ def _get_backends_with_storage_info(storage_info_kind: str): GPU_BACKENDS = _get_backends_with_storage_info("gpu") ALL_BACKENDS = CPU_BACKENDS + GPU_BACKENDS -_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] -PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] - -DACE_BACKENDS = [ - _backend_name_as_param(name) - for name in filter(lambda name: name.startswith("dace:"), _ALL_BACKEND_NAMES) +PERFORMANCE_BACKENDS = [ + _backend_name_as_param(name) for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda") ] -NON_DACE_BACKENDS = [backend for backend in ALL_BACKENDS if backend not in DACE_BACKENDS] @pytest.fixture() diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 8e5f3466d0..66d45abe21 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -30,8 +30,6 @@ from cartesian_tests.definitions import ( ALL_BACKENDS, CPU_BACKENDS, - DACE_BACKENDS, - NON_DACE_BACKENDS, get_array_library, ) from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( @@ -770,20 +768,21 @@ def test( assert (out_arr[:, :, :] == 388.0).all() -@pytest.mark.parametrize("backend", NON_DACE_BACKENDS) -def test_cast_in_index(backend): - @gtscript.stencil(backend) - def cast_in_index( - in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] - ): - """Simple copy stencil with forced cast in index calculation.""" - with computation(PARALLEL), interval(...): - out_field = in_field[0, 0, i32 - i64] +def _xfail_dace_backends(param): + if param.values[0].startswith("dace:"): + marks = param.marks + [ + pytest.mark.xfail( + raises=ValueError, + reason="Missing support in DaCe backends, see https://github.com/GridTools/gt4py/issues/1881.", + ) + ] + # make a copy because otherwise we are operating in-place + return pytest.param(*param.values, marks=marks) + return param -@pytest.mark.parametrize("backend", DACE_BACKENDS) -@pytest.mark.xfail(raises=ValueError) -def test_dace_no_cast_in_index(backend): +@pytest.mark.parametrize("backend", map(_xfail_dace_backends, ALL_BACKENDS)) +def test_cast_in_index(backend): @gtscript.stencil(backend) def cast_in_index( in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] From 587d10723c38fb3914be8629dcb3debcb654eff0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 27 Feb 2025 10:22:16 +0100 Subject: [PATCH 160/178] bug[next]: Respect evaluation order in `InlineCenterDerefLiftVars` (#1883) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes the `InlineCenterDerefLiftVars` pass to respect evaluation order by lazily evaluating the inlined values. Consider the following case that is common in boundary conditions: ``` let(var, (↑deref)(it2))(if ·on_bc then 0 else ·var) ``` Then var should only be dereferenced in case `·on_bc` evalutes to False. Previously we just evaluated all values unconditionally: ``` let(_icdlv_1, ·it)(if ·on_bc then 0 else _icdlv_1) ``` Now we instead create a 0-ary lambda function for `_icdlv_1` and evaluate it when the value is needed. ``` let(_icdlv_1, λ() → ·it)(if ·on_bc then 0 else _icdlv_1()) ``` Note that as a result we do evaluate the function multiple times. To avoid redundant recompuations usage of the common subexpression elimination is required. --- .../inline_center_deref_lift_vars.py | 38 +++++++++++++------ .../transforms_tests/test_fuse_as_fieldop.py | 8 ++-- .../test_inline_center_deref_lift_vars.py | 38 ++++++++++++++++--- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 7bd26d0f19..c0a8c9f1b7 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -36,16 +36,25 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): `let(var, (↑stencil)(it))(·var + ·var)` Directly inlining `var` would increase the size of the tree and duplicate the calculation. - Instead, this pass computes the value at the current location once and replaces all previous - references to `var` by an applied lift which captures this value. + Instead this pass, first takes the iterator `(↑stencil)(it)` and transforms it into a + 0-ary function that evaluates to the value at the current location. - `let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1) + ·(↑(λ() → _icdlv_1))` + `λ() → ·(↑stencil)(it)` + + Then all previous occurences of `var` are replaced by this function. + + `let(_icdlv_1, λ() → ·(↑stencil)(it))(·(↑(λ() → _icdlv_1()) + ·(↑(λ() → _icdlv_1()))` The lift inliner can then later easily transform this into a nice expression: - `let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)` + `let(_icdlv_1, λ() → stencil(it))(_icdlv_1() + _icdlv_1())` + + Finally, recomputation is avoided by using the common subexpression elimination and lamba + inlining (can be configured opcount preserving). Both is up to the caller to do later. + + `λ(_cs_1) → _cs_1 + _cs_1)(stencil(it))` - Note: This pass uses and preserves the `recorded_shifts` annex. + Note: This pass uses and preserves the `domain` and `recorded_shifts` annex. """ PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") @@ -78,20 +87,25 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): assert isinstance(node.fun, itir.Lambda) # to make mypy happy eligible_params = [False] * len(node.fun.params) new_args = [] - bound_scalars: dict[str, itir.Expr] = {} + # values are 0-ary lambda functions that evaluate to the derefed argument. We don't put + # the values themselves here as they might be inside of an if to protected from an oob + # access + evaluators: dict[str, itir.Expr] = {} for i, (param, arg) in enumerate(zip(node.fun.params, node.args)): if cpm.is_applied_lift(arg) and is_center_derefed_only(param): eligible_params[i] = True - bound_arg_name = self.uids.sequential_id(prefix="_icdlv") - capture_lift = im.promote_to_const_iterator(bound_arg_name) + bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv") + capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)()) trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) new_args.append(capture_lift) # since we deref an applied lift here we can (but don't need to) immediately # inline - bound_scalars[bound_arg_name] = InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - ).visit(im.deref(arg), recurse=False) + evaluators[bound_arg_evaluator] = im.lambda_()( + InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit( + im.deref(arg), recurse=False + ) + ) else: new_args.append(arg) @@ -100,6 +114,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): im.call(node.fun)(*new_args), eligible_params=eligible_params ) # TODO(tehrengruber): propagate let outwards - return im.let(*bound_scalars.items())(new_node) + return im.let(*evaluators.items())(new_node) return node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index fd884e239f..14aebd032c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -145,8 +145,8 @@ def test_staged_inlining(): ) expected = im.as_fieldop( im.lambda_("a", "b")( - im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))( - im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2)) + im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("a"), im.deref("b"))))( + im.plus(im.plus(im.call("_icdlv_1")(), 1), im.plus(im.call("_icdlv_1")(), 2)) ) ), d, @@ -328,8 +328,8 @@ def test_chained_fusion(): ) expected = im.as_fieldop( im.lambda_("inp1", "inp2")( - im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))( - im.plus("_icdlv_1", "_icdlv_1") + im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("inp1"), im.deref("inp2"))))( + im.plus(im.call("_icdlv_1")(), im.call("_icdlv_1")()) ) ), d, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py index 6cc2f7cd28..2caa887803 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py @@ -6,20 +6,33 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import cse from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -def wrap_in_program(expr: itir.Expr) -> itir.Program: + +def wrap_in_program(expr: itir.Expr, *, arg_dtypes=None) -> itir.Program: + if arg_dtypes is None: + arg_dtypes = [ts.ScalarKind.FLOAT64] + arg_types = [ts.FieldType(dims=[], dtype=ts.ScalarType(kind=dtype)) for dtype in arg_dtypes] + indices = [i for i in range(1, len(arg_dtypes) + 1)] if len(arg_dtypes) > 1 else [""] return itir.Program( id="f", function_definitions=[], - params=[im.sym("d"), im.sym("inp"), im.sym("out")], + params=[ + *(im.sym(f"inp{i}", type_) for i, type_ in zip(indices, arg_types)), + im.sym("out", field_type), + ], declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")), + expr=im.as_fieldop(im.lambda_(*(f"it{i}" for i in indices))(expr))( + *(im.ref(f"inp{i}") for i in indices) + ), domain=im.call("cartesian_domain")(), target=im.ref("out"), ) @@ -34,7 +47,7 @@ def unwrap_from_program(program: itir.Program) -> itir.Expr: def test_simple(): testee = im.let("var", im.lift("deref")("it"))(im.deref("var")) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -42,7 +55,7 @@ def test_simple(): def test_double_deref(): testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var"))) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -62,3 +75,18 @@ def test_deref_at_multiple_pos(): actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert testee == actual + + +def test_bc(): + # we also check that the common subexpression is able to extract the inlined value, such + # that it is only evaluated once + testee = im.let("var", im.lift("deref")("it2"))( + im.if_(im.deref("it1"), im.literal_from_value(0), im.plus(im.deref("var"), im.deref("var"))) + ) + expected = "(λ(_icdlv_1) → if ·it1 then 0 else (λ(_cs_1) → _cs_1 + _cs_1)(·(↑(λ() → _icdlv_1()))()))(λ() → ·it2)" + + actual = InlineCenterDerefLiftVars.apply( + wrap_in_program(testee, arg_dtypes=[ts.ScalarKind.BOOL, ts.ScalarKind.FLOAT64]) + ) + simplified = unwrap_from_program(cse.CommonSubexpressionElimination.apply(actual)) + assert str(simplified) == expected From 70569bc5e917bf4000f6543573aadc5d31c86c13 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Thu, 27 Feb 2025 16:36:04 +0100 Subject: [PATCH 161/178] fix[cartesian]: Fix minimal k-range computation (#1842) ## Description The computation of the minimal k-ranges that happen during the vaildate-args step is allowing for inconsistent computation to happen. This PR stiffens the requirements on fields: - intervals `[START+X ,END+y)` are now also considered: - `interval(3,-1)` requires a minimal size of 5 for the interval to not be empty - `interval(3,None)` now requires a minimal size of 4 - intervals `[START+X, START+Y)` and `[END+X,END+Y)` are not affected. - empty intervals are still allowed to have a 0-domain as to accomodate 2-dimensional fields - `interval(...)` still requires no k-size ## Requirements - [x] All fixes and/or new features come with corresponding tests. --------- Co-authored-by: Florian Deconinck --- .../cartesian/gtc/passes/gtir_k_boundary.py | 54 +++++++++++-------- .../stencil_definitions.py | 3 +- .../test_code_generation.py | 4 +- .../test_passes/test_min_k_interval.py | 36 +++++++++---- 4 files changed, 63 insertions(+), 34 deletions(-) diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py index 96cec5b6d4..40c31dca53 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -41,20 +41,21 @@ def visit_FieldAccess( node: gtir.FieldAccess, vloop: gtir.VerticalLoop, field_boundaries: Dict[str, Tuple[Union[float, int], Union[float, int]]], - include_center_interval: bool, - **kwargs: Any, + **_: Any, ): boundary = field_boundaries[node.name] interval = vloop.interval if not isinstance(node.offset, gtir.VariableKOffset): - if interval.start.level == LevelMarker.START and ( - include_center_interval or interval.end.level == LevelMarker.START - ): - boundary = (max(-interval.start.offset - node.offset.k, boundary[0]), boundary[1]) - if ( - include_center_interval or interval.start.level == LevelMarker.END - ) and interval.end.level == LevelMarker.END: - boundary = (boundary[0], max(interval.end.offset + node.offset.k, boundary[1])) + if interval.start.level == LevelMarker.START: + boundary = ( + max(-interval.start.offset - node.offset.k, boundary[0]), + boundary[1], + ) + if interval.end.level == LevelMarker.END: + boundary = ( + boundary[0], + max(interval.end.offset + node.offset.k, boundary[1]), + ) if node.name in [decl.name for decl in vloop.temporaries] and ( boundary[0] > 0 or boundary[1] > 0 ): @@ -63,24 +64,35 @@ def visit_FieldAccess( field_boundaries[node.name] = boundary -def compute_k_boundary( - node: gtir.Stencil, include_center_interval=True -) -> Dict[str, Tuple[int, int]]: +def compute_k_boundary(node: gtir.Stencil) -> Dict[str, Tuple[int, int]]: # loop from START to END is not considered as it might be empty. additional check possible in the future - return KBoundaryVisitor().visit(node, include_center_interval=include_center_interval) + return KBoundaryVisitor().visit(node) -def compute_min_k_size(node: gtir.Stencil, include_center_interval=True) -> int: +def compute_min_k_size(node: gtir.Stencil) -> int: """Compute the required number of k levels to run a stencil.""" + min_size_start = 0 min_size_end = 0 + biggest_offset = 0 for vloop in node.vertical_loops: - if vloop.interval.start.level == LevelMarker.START and ( - include_center_interval or vloop.interval.end.level == LevelMarker.START + if ( + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.END ): - min_size_start = max(min_size_start, vloop.interval.end.offset) + if not (vloop.interval.start.offset == 0 and vloop.interval.end.offset == 0): + biggest_offset = max( + biggest_offset, + vloop.interval.start.offset - vloop.interval.end.offset + 1, + ) elif ( - include_center_interval or vloop.interval.start.level == LevelMarker.END - ) and vloop.interval.end.level == LevelMarker.END: + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.START + ): + min_size_start = max(min_size_start, vloop.interval.end.offset) + biggest_offset = max(biggest_offset, vloop.interval.end.offset) + else: min_size_end = max(min_size_end, -vloop.interval.start.offset) - return min_size_start + min_size_end + biggest_offset = max(biggest_offset, -vloop.interval.start.offset) + minimal_size = max(min_size_start + min_size_end, biggest_offset) + return minimal_size diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 8112866092..217c0ee488 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -295,7 +295,8 @@ def large_k_interval(in_field: Field3D, out_field: Field3D): with computation(PARALLEL): with interval(0, 6): out_field = in_field - with interval(6, -10): # this stage will only run if field has more than 16 elements + # this stenicl is only legal to call with fields that have more than 16 elements + with interval(6, -10): out_field = in_field + 1 with interval(-10, None): out_field = in_field diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 66d45abe21..4e0fa8903c 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -56,8 +56,8 @@ def test_generation(name, backend): ) else: args[k] = v(1.5) - # vertical domain size >= 16 required for test_large_k_interval - stencil(**args, origin=(10, 10, 5), domain=(3, 3, 16)) + # vertical domain size > 16 required for test_large_k_interval + stencil(**args, origin=(10, 10, 5), domain=(3, 3, 17)) @pytest.mark.parametrize("backend", ALL_BACKENDS) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py index 078adcc8da..6bb4ec63f6 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py @@ -16,7 +16,10 @@ from gt4py import cartesian as gt4pyc from gt4py.cartesian import gtscript as gs from gt4py.cartesian.backend import from_name -from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary, compute_min_k_size +from gt4py.cartesian.gtc.passes.gtir_k_boundary import ( + compute_k_boundary, + compute_min_k_size, +) from gt4py.cartesian.gtc.passes.gtir_pipeline import prune_unused_parameters from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil from gt4py.cartesian.stencil_builder import StencilBuilder @@ -48,21 +51,21 @@ def stencil_no_extent_0(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(0, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_1(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(-1, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(-1, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(1, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(max(0, -2), max(-2, -2)), 0), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=4) @typing.no_type_check def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -73,14 +76,14 @@ def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(0, max(-1, 0)), min_k_size=1) +@register_test_case(k_bounds=(0, 0), min_k_size=1) @typing.no_type_check def stencil_no_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(-1, None): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -1), max(-2, 0)), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=3) @typing.no_type_check def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 1): @@ -89,6 +92,13 @@ def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] +@register_test_case(k_bounds=(-1, -2), min_k_size=4) +@typing.no_type_check +def stencil_no_extent_6(field_a: gs.Field[float], field_b: gs.Field[float]): + with computation(PARALLEL), interval(1, -2): + field_a[0, 0, 0] = field_b[0, 0, 0] + + # stencils with extent @register_test_case(k_bounds=(5, -5), min_k_size=0) @typing.no_type_check @@ -111,7 +121,7 @@ def stencil_with_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 5] -@register_test_case(k_bounds=(3, -3), min_k_size=3) +@register_test_case(k_bounds=(3, -3), min_k_size=4) @typing.no_type_check def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -122,7 +132,7 @@ def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, -3] -@register_test_case(k_bounds=(-5, 5), min_k_size=1) +@register_test_case(k_bounds=(-5, 5), min_k_size=2) @typing.no_type_check def stencil_with_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, -1): @@ -171,7 +181,10 @@ def test_min_k_size(definition, expected_min_k_size): @pytest.mark.parametrize("definition,expected", test_data) def test_k_bounds_exec(definition, expected): - expected_k_bounds, expected_min_k_size = expected["k_bounds"], expected["min_k_size"] + expected_k_bounds, expected_min_k_size = ( + expected["k_bounds"], + expected["min_k_size"], + ) required_field_size = expected_min_k_size + expected_k_bounds[0] + expected_k_bounds[1] @@ -234,7 +247,10 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b: @pytest.mark.parametrize( "definition", - [stencil_with_invalid_temporary_access_start, stencil_with_invalid_temporary_access_end], + [ + stencil_with_invalid_temporary_access_start, + stencil_with_invalid_temporary_access_end, + ], ) def test_invalid_temporary_access(definition): builder = StencilBuilder(definition, backend=from_name("numpy")) From 9d1bdd6000bf5cc484f84bcfb96c33da6204acfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:12:09 +0100 Subject: [PATCH 162/178] feat[next][dace]: Added `MultiStateGlobalSelfCopyElimination` (#1890) The transformation `MultiStateGlobalSelfCopyElimination` is very similar to the `GT4PyGlobalSelfCopyElimination` but they target slightly different cases. The transformation `GT4PyGlobalSelfCopyElimination`, which is already there and was renamed to `SingleStateGlobalSelfCopyElimination`, handles the pattern `(G) -> (T) -> (G)`, i.e. the global data `G` is copied into the transient `T` is then immediately copied back into `G`. Because of ADR-18 we know that this has no effect, because `G` is used as input and output and must therefore be point wise, so `G[i, j]` in the output can only be `G[i, j]` at the beginning. The new transformation `MultiStateGlobalSelfCopyElimination` handles a different case, it looks for patterns `(G) -> (T)` and `(T) -> (G)`, which is essentially the same, but this time the definition of `T` and the write back of `T` into `G` does not need to be in the same state. In the long run, the two transformation should be combined. --------- Co-authored-by: edopao --- .../runners/dace/transformations/__init__.py | 10 +- .../redundant_array_removers.py | 564 ++++++++++++++++++ .../runners/dace/transformations/simplify.py | 167 +----- ...ulti_state_global_self_copy_elimination.py | 514 ++++++++++++++++ ...gle_state_global_self_copy_elimination.py} | 8 +- 5 files changed, 1116 insertions(+), 147 deletions(-) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/{test_global_self_copy_elimination.py => test_single_state_global_self_copy_elimination.py} (93%) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index df48d35d39..81ecb107cb 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -24,9 +24,13 @@ from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial from .map_orderer import MapIterationOrder, gt_set_iteration_order from .map_promoter import SerialMapPromoter +from .redundant_array_removers import ( + MultiStateGlobalSelfCopyElimination, + SingleStateGlobalSelfCopyElimination, + gt_multi_state_global_self_copy_elimination, +) from .simplify import ( GT_SIMPLIFY_DEFAULT_SKIP_SET, - GT4PyGlobalSelfCopyElimination, GT4PyMapBufferElimination, GT4PyMoveTaskletIntoMap, gt_inline_nested_sdfg, @@ -47,7 +51,6 @@ __all__ = [ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", - "GT4PyGlobalSelfCopyElimination", "GT4PyMapBufferElimination", "GT4PyMoveTaskletIntoMap", "LoopBlocking", @@ -55,8 +58,10 @@ "MapFusionParallel", "MapFusionSerial", "MapIterationOrder", + "MultiStateGlobalSelfCopyElimination", "SerialMapPromoter", "SerialMapPromoterGPU", + "SingleStateGlobalSelfCopyElimination", "gt_auto_optimize", "gt_change_transient_strides", "gt_create_local_double_buffering", @@ -67,6 +72,7 @@ "gt_make_transients_persistent", "gt_map_strides_to_dst_nested_sdfg", "gt_map_strides_to_src_nested_sdfg", + "gt_multi_state_global_self_copy_elimination", "gt_propagate_strides_from_access_node", "gt_propagate_strides_of", "gt_reduce_distributed_buffering", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py new file mode 100644 index 0000000000..5a0e117b21 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -0,0 +1,564 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import pass_pipeline as dace_ppl + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +def gt_multi_state_global_self_copy_elimination( + sdfg: dace.SDFG, + validate: bool = False, +) -> Optional[dict[dace.SDFG, set[str]]]: + """Runs `MultiStateGlobalSelfCopyElimination` on the SDFG recursively. + + For the return value see `MultiStateGlobalSelfCopyElimination.apply_pass()`. + """ + pipeline = dace_ppl.Pipeline([gtx_transformations.MultiStateGlobalSelfCopyElimination()]) + res = pipeline.apply_pass(sdfg, {}) + + if validate: + sdfg.validate() + + if "MultiStateGlobalSelfCopyElimination" not in res: + return None + return res["MultiStateGlobalSelfCopyElimination"][sdfg] + + +@dace_properties.make_properties +class MultiStateGlobalSelfCopyElimination(dace_transformation.Pass): + """Removes self copying across different states. + + This transformation is very similar to `SingleStateGlobalSelfCopyElimination`, but + addresses a slightly different case. Assume we have the pattern `(G) -> (T)` + in one state, i.e. the global data `G` is copied into a transient. In another + state, we have the pattern `(T) -> (G)`, i.e. the data is written back. + + If the following conditions are satisfied, this transformation will remove all + writes to `G`: + - The only write access to `G` happens in the `(T) -> (G)` pattern. ADR-18 + guarantees, that if `G` is used as an input and output it must be pointwise. + Thus there is no weird shifting. + + If the only usage of `T` is to write into `G` then the transient `T` will be + removed. + + Note that this transformation does not consider the subsets of the writes from + `T` to `G` because ADR-18 guarantees to us, that _if_ `G` is a genuine input + and output, then the `G` read and write subsets have the exact same range. + If `G` is not an output then any mutating changes to `G` would be invalid. + + Todo: + - Implement the pattern `(G) -> (T) -> (G)` which is handled currently by + `SingleStateGlobalSelfCopyElimination`, see `_classify_candidate()` and + `_remove_writes_to_global()` for more. + - Make it more efficient such that the SDFG is not scanned multiple times. + """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.FindAccessStates, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFG, set[str]]]: + """Applies the pass. + + The function will return a `dict` that contains for every SDFG, the name + of the processed data descriptors. If a name refers to a global memory, + then it means that all write backs, i.e. `(T) -> (G)` patterns, have + been removed for that `G`. If the name refers to a data descriptor that no + longer exists, then it means that the write `(G) -> (T)` was also eliminated. + Currently there is no possibility to identify which transient name belonged + to a global name. + """ + assert "FindAccessStates" in pipeline_results + + result: dict[dace.SDFG, set[str]] = dict() + for nsdfg in sdfg.all_sdfgs_recursive(): + single_level_res: set[str] = self._process_sdfg(nsdfg, pipeline_results) + if single_level_res: + result[nsdfg] = single_level_res + + return result if result else None + + def _process_sdfg( + self, + sdfg: dace.SDFG, + pipeline_results: dict[str, Any], + ) -> set[str]: + """Apply the pass to a single level of an SDFG, i.e. do not handle nested SDFG. + + The return value of this function is the same as for `apply_pass()`, but + only for the SDFG that was passed. + """ + t_mapping = self._find_candidates(sdfg, pipeline_results) + if len(t_mapping) == 0: + return set() + self._remove_writes_to_globals(sdfg, t_mapping, pipeline_results) + removed_transients = self._remove_transient_buffers_if_possible( + sdfg, t_mapping, pipeline_results + ) + + return removed_transients | t_mapping.keys() + + def _find_candidates( + self, + sdfg: dace.SDFG, + pipeline_results: dict[str, Any], + ) -> dict[str, set[str]]: + """The function searches for all candidates of that must be processed. + + The function returns a `dict` that maps the name of a global memory, `G` in + the above pattern, to the name of the buffer transient, `T` in the above + pattern. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + global_data = [ + aname + for aname, desc in sdfg.arrays.items() + if not desc.transient + and isinstance(desc, dace_data.Array) + and not isinstance(desc, dace_data.View) + ] + + candidates: dict[str, set[str]] = dict() + for gname in global_data: + candidate_tnames = self._classify_candidate(sdfg, gname, access_states) + if candidate_tnames is not None: + assert len(candidate_tnames) > 0 + candidates[gname] = candidate_tnames + + return candidates + + def _classify_candidate( + self, + sdfg: dace.SDFG, + gname: str, + access_states: dict[str, set[dace.SDFGState]], + ) -> Optional[set[str]]: + """The function tests if the global data `gname` can be handled. + + It essentially checks all conditions above, which is that the global is + only written through transients that are fully defined by the data itself. + writes to it are through transients that are fully defined by the data + itself. + + The function returns `None` if `gname` can not be handled by the function. + If `gname` can be handled the function returns a set of all data descriptors + that act as distributed buffers. + """ + # The set of access nodes that reads from the global, i.e. `gname`, essentially + # the set of candidates of `T` defined through the way it is defined. + # And the same set, but this time defined through who writes into the global. + reads_from_g: set[str] = set() + writes_to_g: set[str] = set() + + # In a first step we will identify the possible `T` only from the angle of + # how they interact with `G`. At a later point we will look at the `T` again, + # because in case of branches there might be multiple definitions of `T`. + for state in access_states[gname]: + for dnode in state.data_nodes(): + if dnode.data != gname: + continue + + # Note that we allow that `G` can be written to by multiple `T` at + # once. However, we require that all this data, is fully defined by + # a read to `G` itself. + for iedge in state.in_edges(dnode): + possible_t = iedge.src + + # If `G` is a pseudo output, see definition above, then it is only + # allowed that access nodes writes to them. Note, that here we + # will only collect which nodes writes to `G`, if these are + # valid `T`s will be checked later, after we cllected all of them. + if not isinstance(possible_t, dace_nodes.AccessNode): + return None + + possible_t_desc = possible_t.desc(sdfg) + if not possible_t_desc.transient: + return None # we must write into a transient. + if isinstance(possible_t_desc, dace_data.View): + return None # The global data must be written to from an array + if not isinstance(possible_t_desc, dace_data.Array): + return None + writes_to_g.add(possible_t.data) + + # Let's look who reads from `g` this will contribute to the `reads_from_g` set. + for oedge in state.out_edges(dnode): + possible_t = oedge.dst + # `T` must be an AccessNode. Note that it is not important + # what also reads from `G`. We just have to find everything that + # can act as `T`. + if not isinstance(possible_t, dace_nodes.AccessNode): + continue + + # It is important that only `G` defines `T`, so it must have + # an incoming degree of one, since we have SSA. + if state.in_degree(possible_t) != 1: + continue + + # `T` must fulfil some condition, like that it is transient. + possible_t_desc = possible_t.desc(sdfg) + if not possible_t_desc.transient: + continue # we must write into a transient. + if isinstance(possible_t_desc, dace_data.View): + continue # We must write into an array and not a view. + if not isinstance(possible_t_desc, dace_data.Array): + continue + + # Currently we do not handle the pattern `(T) -> (G) -> (T)`, + # see `_remove_writes_to_global()` for more, thus we filter + # this pattern here. + if any( + tnode_oedge.dst.data == gname + for tnode_oedge in state.out_edges(possible_t) + if isinstance(tnode_oedge.dst, dace_nodes.AccessNode) + ): + return None + + # Now add the data to the list of data that reads from `G`. + reads_from_g.add(possible_t.data) + + if len(writes_to_g) == 0: + return None + + # Now every write to `G` necessarily comes from an access node that was created + # by a direct read from `G`. We ensure this by checking that `writes_to_g` is + # a subset of `reads_to_g`. + # Note that the `T` nodes might not be unique, which happens in case + # of separate memlets for different subsets. + # of different subsets, are contained in ` + if not writes_to_g.issubset(reads_from_g): + return None + + # If we have branches, it might be that different data is written to `T` depending + # on which branch is selected, i.e. `T = G if cond else foo(A)`. For that + # reason we must now check that `G` is the only data source of `T`, but this + # time we must do the check on `T`. Note we only have to remove the particular access node + # to `T` where `G` is the only data source, while we keep the other access nodes. + # `T`. + for tname in list(writes_to_g): + for state in access_states[tname]: + for dnode in state.data_nodes(): + if dnode.data != tname: + continue + if state.in_degree(dnode) == 0: + continue # We are only interested at definitions. + + # Now ensures that only `gname` defines `T`. + for iedge in state.in_edges(dnode): + t_def_node = iedge.src + if not isinstance(t_def_node, dace_nodes.AccessNode): + writes_to_g.discard(tname) + break + if t_def_node.data != gname: + writes_to_g.discard(tname) + break + if tname not in writes_to_g: + break + + return None if len(writes_to_g) == 0 else writes_to_g + + def _remove_writes_to_globals( + self, + sdfg: dace.SDFG, + t_mapping: dict[str, set[str]], + pipeline_results: dict[str, Any], + ) -> None: + """Remove all writes to the global data defined through `t_mapping`. + + The function does not handle reads from the global to the transients. + + Args: + sdfg: The SDFG on which we should process. + t_mapping: Maps the name of the global data to the transient data. + This set was computed by the `_find_candidates()` function. + pipeline_results: The results of the pipeline. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + for gname, tnames in t_mapping.items(): + self._remove_writes_to_global( + sdfg=sdfg, gname=gname, tnames=tnames, access_states=access_states + ) + + def _remove_writes_to_global( + self, + sdfg: dace.SDFG, + gname: str, + tnames: set[str], + access_states: dict[str, set[dace.SDFGState]], + ) -> None: + """Remove writes to the global data `gname`. + + The function is the same as `_remove_writes_to_globals()` but only processes + one global data descriptor. + """ + # Here we delete the `T` node that writes into `G`, this might turn the `G` + # node into an isolated node. + # It is important that this code does not handle the `(G) -> (T) -> (G)` + # pattern, which is difficult to handle. The issue is that by removing `(T)`, + # what this function does, it also removes the definition `(T)`. However, + # it can only do that if it ensures that `T` is not used anywhere else. + # This is currently handle by the `SingleStateGlobalSelfCopyElimination` pass + # and the classifier rejects this pattern. + for state in access_states[gname]: + for dnode in list(state.data_nodes()): + if dnode.data != gname: + continue + for iedge in list(state.in_edges(dnode)): + tnode = iedge.src + if not isinstance(tnode, dace_nodes.AccessNode): + continue + if tnode.data in tnames: + state.remove_node(tnode) + + # It might be that the `dnode` has become isolated so remove it. + if state.degree(dnode) == 0: + state.remove_node(dnode) + + def _remove_transient_buffers_if_possible( + self, + sdfg: dace.SDFG, + t_mapping: dict[str, set[str]], + pipeline_results: dict[str, Any], + ) -> set[str]: + """Remove the transient data if it is possible, listed in `t_mapping`. + + Essentially the function will look if there is a read to any data that is + mentioned in `tnames`. If there isn't it will remove the write to it and + remove it from the registry. + The function must run after `_remove_writes_to_globals()`. + + The function returns the list of transients that were eliminated. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + result: set[str] = set() + for gname, tnames in t_mapping.items(): + result.update( + self._remove_transient_buffer_if_possible( + sdfg=sdfg, + gname=gname, + tnames=tnames, + access_states=access_states, + ) + ) + return result + + def _remove_transient_buffer_if_possible( + self, + sdfg: dace.SDFG, + gname: str, + tnames: set[str], + access_states: dict[str, set[dace.SDFGState]], + ) -> set[str]: + obsolete_ts: set[str] = set() + for tname in tnames: + # We can remove the (defining) write to `T` only if it is not read + # anywhere else. + if self._has_read_access_for(sdfg, tname, access_states): + continue + # Now we look for all writes to `tname` and remove them, since there + # are no reads. + for state in access_states[tname]: + neighbourhood: set[dace_nodes.Node] = set() + for dnode in list(state.data_nodes()): + if dnode.data == tname: + # We have to store potential sources nodes, which is `G`. + # This is because the local `G` node could become isolated. + # We do not need to consider the outgoing edges, because + # they are reads which we have handled above. + for iedge in state.in_edges(dnode): + assert ( + isinstance(iedge.src, dace_nodes.AccessNode) + and iedge.src.data == gname + ) + neighbourhood.add(iedge.src) + state.remove_node(dnode) + obsolete_ts.add(dnode.data) + + # We now have to check if an node has become isolated. + for nh_node in neighbourhood: + if state.degree(nh_node) == 0: + state.remove_node(nh_node) + + for tname in obsolete_ts: + sdfg.remove_data(tname, validate=False) + + return obsolete_ts + + def _has_read_access_for( + self, + sdfg: dace.SDFG, + dname: str, + access_states: dict[str, set[dace.SDFGState]], + ) -> bool: + """Checks if there is a read access on `dname`.""" + for state in access_states[dname]: + for dnode in state.data_nodes(): + if state.out_degree(dnode) == 0: + continue # We are only interested in read accesses + if dnode.data == dname: + return True + return False + + +@dace_properties.make_properties +class SingleStateGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation): + """Remove global self copy. + + This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G` + is read from and written too at the same time, however, in between is `T` + used as a buffer. In the example above `G` is a global memory and `T` is a + temporary. This situation is generated by the lowering if the data node is + not needed (because the computation on it is only conditional). + + In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can + only have a point wise dependency of the output on the input. + This transformation will remove the write into `G`, i.e. we thus only have + `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed + if `T` is not used downstream. If it is used `T` will be maintained. + """ + + node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + read_g = self.node_read_g + write_g = self.node_write_g + tmp_node = self.node_tmp + g_desc = read_g.desc(sdfg) + tmp_desc = tmp_node.desc(sdfg) + + # NOTE: We do not check if `G` is read downstream. + if read_g.data != write_g.data: + return False + if g_desc.transient: + return False + if not tmp_desc.transient: + return False + if graph.in_degree(read_g) != 0: + return False + if graph.out_degree(read_g) != 1: + return False + if graph.degree(tmp_node) != 2: + return False + if graph.in_degree(write_g) != 1: + return False + if graph.out_degree(write_g) != 0: + return False + if graph.scope_dict()[read_g] is not None: + return False + + return True + + def _is_read_downstream( + self, + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + ) -> bool: + """Scans for reads to `data_to_look`. + + The function will go through states that are reachable from `start_state` + (including) and test if there is a read to the data container `data_to_look`. + It will return `True` the first time it finds such a node. + It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` + are ignored. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + + Todo: + Port this function to use DaCe pass pipeline. + """ + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # TODO(phimuell): Run the `StateReachability` pass in a pipeline and use + # the `_pipeline_results` member to access the data. + return gtx_transformations.utils.is_accessed_downstream( + start_state=start_state, + sdfg=sdfg, + reachable_states=None, + data_to_look=data_to_look, + nodes_to_ignore={read_g, write_g, tmp_node}, + ) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # We first check if `T`, the intermediate is not used downstream. In this + # case we can remove the read to `G` and `T` itself from the SDFG. + # We have to do this check before, because the matching is not fully stable. + is_tmp_used_downstream = self._is_read_downstream( + start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data + ) + + # The write to `G` can always be removed. + graph.remove_node(write_g) + + # Also remove the read to `G` and `T` from the SDFG if possible. + if not is_tmp_used_downstream: + graph.remove_node(read_g) + graph.remove_node(tmp_node) + # It could still be used in a parallel branch. + try: + sdfg.remove_data(tmp_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): + raise diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index f1fa65a716..1c2541ed99 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -57,8 +57,12 @@ def gt_simplify( - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. Further, the function will run the following passes in addition to DaCe simplify: - - `GT4PyGlobalSelfCopyElimination`: Special copy pattern that in the context - of GT4Py based SDFG behaves as a no op. + - `SingleStateGlobalSelfCopyElimination`: Special copy pattern that in the context + of GT4Py based SDFG behaves as a no op, i.e. `(G) -> (T) -> (G)`. + - `MultiStateGlobalSelfCopyElimination`: Very similar to + `SingleStateGlobalSelfCopyElimination`, with the exception that the write to + `T`, i.e. `(G) -> (T)` and the write back to `G`, i.e. `(T) -> (G)` might be + in different states. Furthermore, by default, or if `None` is passed for `skip` the passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. @@ -111,17 +115,31 @@ def gt_simplify( result = result or {} result.update(simplify_res) - if "GT4PyGlobalSelfCopyElimination" not in skip: + if "SingleStateGlobalSelfCopyElimination" not in skip: self_copy_removal_result = sdfg.apply_transformations_repeated( - GT4PyGlobalSelfCopyElimination(), + gtx_transformations.SingleStateGlobalSelfCopyElimination(), validate=validate, validate_all=validate_all, ) if self_copy_removal_result > 0: at_least_one_xtrans_run = True result = result or {} - result.setdefault("GT4PyGlobalSelfCopyElimination", 0) - result["GT4PyGlobalSelfCopyElimination"] += self_copy_removal_result + if "SingleStateGlobalSelfCopyElimination" not in result: + result["SingleStateGlobalSelfCopyElimination"] = 0 + result["SingleStateGlobalSelfCopyElimination"] += self_copy_removal_result + + if "MultiStateGlobalSelfCopyElimination" not in skip: + distributed_self_copy_result = ( + gtx_transformations.gt_multi_state_global_self_copy_elimination( + sdfg, validate=validate_all + ) + ) + if distributed_self_copy_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "MultiStateGlobalSelfCopyElimination" not in result: + result["MultiStateGlobalSelfCopyElimination"] = set() + result["MultiStateGlobalSelfCopyElimination"].update(distributed_self_copy_result) return result @@ -247,141 +265,10 @@ def gt_reduce_distributed_buffering( if ret is not None: all_result[rsdfg] = ret - return all_result - - -@dace_properties.make_properties -class GT4PyGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation): - """Remove global self copy. - - This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G` - is read from and written too at the same time, however, in between is `T` - used as a buffer. In the example above `G` is a global memory and `T` is a - temporary. This situation is generated by the lowering if the data node is - not needed (because the computation on it is only conditional). - - In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can - only have a point wise dependency of the output on the input. - This transformation will remove the write into `G`, i.e. we thus only have - `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed - if `T` is not used downstream. If it is used `T` will be maintained. - """ - - node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) - node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) - - def __init__( - self, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - - @classmethod - def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] - - def can_be_applied( - self, - graph: dace.SDFGState | dace.SDFG, - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - read_g = self.node_read_g - write_g = self.node_write_g - tmp_node = self.node_tmp - g_desc = read_g.desc(sdfg) - tmp_desc = tmp_node.desc(sdfg) - - # NOTE: We do not check if `G` is read downstream. - if read_g.data != write_g.data: - return False - if g_desc.transient: - return False - if not tmp_desc.transient: - return False - if graph.in_degree(read_g) != 0: - return False - if graph.out_degree(read_g) != 1: - return False - if graph.degree(tmp_node) != 2: - return False - if graph.in_degree(write_g) != 1: - return False - if graph.out_degree(write_g) != 0: - return False - if graph.scope_dict()[read_g] is not None: - return False - - return True - - def _is_read_downstream( - self, - start_state: dace.SDFGState, - sdfg: dace.SDFG, - data_to_look: str, - ) -> bool: - """Scans for reads to `data_to_look`. - - The function will go through states that are reachable from `start_state` - (including) and test if there is a read to the data container `data_to_look`. - It will return `True` the first time it finds such a node. - It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` - are ignored. + if len(all_result) == 0: + return None - Args: - start_state: The state where the scanning starts. - sdfg: The SDFG on which we operate. - data_to_look: The data that we want to look for. - - Todo: - Port this function to use DaCe pass pipeline. - """ - read_g: dace_nodes.AccessNode = self.node_read_g - write_g: dace_nodes.AccessNode = self.node_write_g - tmp_node: dace_nodes.AccessNode = self.node_tmp - - # TODO(phimuell): Run the `StateReachability` pass in a pipeline and use - # the `_pipeline_results` member to access the data. - return gtx_transformations.utils.is_accessed_downstream( - start_state=start_state, - sdfg=sdfg, - reachable_states=None, - data_to_look=data_to_look, - nodes_to_ignore={read_g, write_g, tmp_node}, - ) - - def apply( - self, - graph: dace.SDFGState | dace.SDFG, - sdfg: dace.SDFG, - ) -> None: - read_g: dace_nodes.AccessNode = self.node_read_g - write_g: dace_nodes.AccessNode = self.node_write_g - tmp_node: dace_nodes.AccessNode = self.node_tmp - - # We first check if `T`, the intermediate is not used downstream. In this - # case we can remove the read to `G` and `T` itself from the SDFG. - # We have to do this check before, because the matching is not fully stable. - is_tmp_used_downstream = self._is_read_downstream( - start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data - ) - - # The write to `G` can always be removed. - graph.remove_node(write_g) - - # Also remove the read to `G` and `T` from the SDFG if possible. - if not is_tmp_used_downstream: - graph.remove_node(read_g) - graph.remove_node(tmp_node) - # It could still be used in a parallel branch. - try: - sdfg.remove_data(tmp_node.data, validate=True) - except ValueError as e: - if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): - raise + return all_result AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py new file mode 100644 index 0000000000..2eba2ce51c --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py @@ -0,0 +1,514 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from typing import Optional + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace.transformation import pass_pipeline as dace_ppl + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def apply_distributed_self_copy_elimination( + sdfg: dace.SDFG, +) -> Optional[dict[dace.SDFG, set[str]]]: + return gtx_transformations.gt_multi_state_global_self_copy_elimination(sdfg=sdfg, validate=True) + + +def _make_not_apply_because_of_write_to_g_sdfg() -> dace.SDFG: + """This SDFG is not eligible, because there is a write to `G`.""" + sdfg = dace.SDFG(util.unique_name("not_apply_because_of_write_to_g_sdfg")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(5,), dtype=dace.float64, transient=True) + + # This is an unrelated array that is used as output. + sdfg.add_array( + name="b", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge(state1.add_access("a"), state1.add_access("t"), dace.Memlet("a[0:5] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + state2.add_mapped_tasklet( + "make_a_non_applicable", + map_ranges={"__i": "3:8"}, + inputs={}, + code="__out = 10.", + outputs={"__out": dace.Memlet("a[__i]")}, + external_edges=True, + ) + + state3 = sdfg.add_state_after(state2) + a3 = state3.add_access("a") + state3.add_nedge(state3.add_access("t"), a3, dace.Memlet("t[0:5] -> [0:5]")) + state3.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a3}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_eligible_sdfg_1() -> dace.SDFG: + """This SDFG is very similar to the one generated by `_make_not_apply_because_of_write_to_g_sdfg()`. + + The main difference is that there is no mutating write to `a` and thus the + transformation applies. + """ + sdfg = dace.SDFG(util.unique_name("_make_eligible_sdfg_1")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(5,), dtype=dace.float64, transient=True) + + # These are some unrelated arrays that is used as output. + sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array(name="c", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge(state1.add_access("a"), state1.add_access("t"), dace.Memlet("a[0:5] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + state2.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + external_edges=True, + ) + + state3 = sdfg.add_state_after(state2) + a3 = state3.add_access("a") + state3.add_nedge(state3.add_access("t"), a3, dace.Memlet("t[0:5] -> [0:5]")) + state3.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("c[__i]")}, + input_nodes={a3}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg1() -> dace.SDFG: + """Generates an SDFG in which `G` is saved into different temporaries.""" + sdfg = dace.SDFG(util.unique_name("multiple_temporaries")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the first `T` array. + sdfg.add_array(name="t1", shape=(5,), dtype=dace.float64, transient=True) + # This is the second `T` array. + sdfg.add_array(name="t2", shape=(5,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. + sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + a1 = state1.add_access("a") + state1.add_nedge(a1, state1.add_access("t1"), dace.Memlet("a[0:5] -> [0:5]")) + state1.add_nedge(a1, state1.add_access("t2"), dace.Memlet("a[5:10] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + + state2.add_nedge(state2.add_access("t1"), a2, dace.Memlet("t1[0:5] -> [0:5]")) + state2.add_nedge(state2.add_access("t2"), a2, dace.Memlet("t2[0:5] -> [5:10]")) + + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg2() -> dace.SDFG: + """Generates an SDFG where there are multiple `T` used. + + The main difference between the SDFG produced by this function and the one + generated by `_make_multiple_temporaries_sdfg()` is that the temporaries + are used sequentially. + """ + sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the first `T` array. + sdfg.add_array(name="t1", shape=(5,), dtype=dace.float64, transient=True) + # This is the second `T` array. + sdfg.add_array(name="t2", shape=(5,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. + sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array(name="c", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge( + state1.add_access("a"), state1.add_access("t1"), dace.Memlet("a[0:5] -> [0:5]") + ) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + + state2.add_nedge(state2.add_access("t1"), a2, dace.Memlet("t1[0:5] -> [0:5]")) + + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + # This essentially repeats the same thing as above again, but this time with `t2`. + state3 = sdfg.add_state_after(state2) + state3.add_nedge( + state3.add_access("a"), state3.add_access("t2"), dace.Memlet("a[5:10] -> [0:5]") + ) + + state4 = sdfg.add_state_after(state3) + a4 = state4.add_access("a") + state4.add_nedge(state4.add_access("t2"), a4, dace.Memlet("t2[0:5] -> [5:10]")) + state4.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in - 1.", + outputs={"__out": dace.Memlet("c[__i]")}, + input_nodes={a4}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg_keep_one_1() -> dace.SDFG: + """ + The generated SDFG is very similar to `_make_multiple_temporaries_sdfg1()` except + that `t1` can not be removed because it is used to generate `c`. + """ + sdfg = _make_multiple_temporaries_sdfg1() + + sdfg.add_array("c", shape=(5,), dtype=dace.float64, transient=False) + + state = sdfg.add_state_after( + next(iter(state for state in sdfg.states() if sdfg.out_degree(state) == 0)) + ) + state.add_mapped_tasklet( + "comp_that_needs_t1", + map_ranges={"__j": "0:5"}, + inputs={"__in": dace.Memlet("t1[__j]")}, + code="__out = __in + 4.0", + outputs={"__out": dace.Memlet("c[__j]")}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg_keep_one_2() -> dace.SDFG: + """ + The generated SDFG is very similar to `_make_multiple_temporaries_sdfg2()` except + that `t1` can not be removed because it is used to generate `d`. + """ + sdfg = _make_multiple_temporaries_sdfg2() + + sdfg.add_array("d", shape=(5,), dtype=dace.float64, transient=False) + + state = sdfg.add_state_after( + next(iter(state for state in sdfg.states() if sdfg.out_degree(state) == 0)) + ) + state.add_mapped_tasklet( + "comp_that_needs_t1", + map_ranges={"__j": "0:5"}, + inputs={"__in": dace.Memlet("t1[__j]")}, + code="__out = __in + 4.0", + outputs={"__out": dace.Memlet("d[__j]")}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_non_eligible_because_of_pseudo_temporary() -> dace.SDFG: + """Generates an SDFG that that defines `T` from two souces, which is not handled. + + Note that in this particular case it would be possible, but we do not support it. + """ + sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(10,), dtype=dace.float64, transient=True) + + # This is the array that also writes to `T` and thus makes it inapplicable. + sdfg.add_array(name="pg", shape=(10,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. + sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + t1 = state1.add_access("t") + state1.add_nedge(state1.add_access("a"), t1, dace.Memlet("a[0:5] -> [0:5]")) + state1.add_nedge(state1.add_access("pg"), t1, dace.Memlet("pg[0:5] -> [5:10]")) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + state2.add_nedge(state2.add_access("t"), a2, dace.Memlet("t[0:5] -> [0:5]")) + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_wb_single_state_sdfg() -> dace.SDFG: + """Generates an SDFG with the pattern `(G) -> (T) -> (G)` which is not handled. + + This pattern is handled by the `SingleStateGlobalSelfCopyElimination` transformation. + """ + sdfg = dace.SDFG(util.unique_name("single_state_write_back_sdfg")) + + sdfg.add_array("g", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) + sdfg.add_array("b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + t1 = state1.add_access("t") + state1.add_nedge(state1.add_access("g"), t1, dace.Memlet("g[0:10] -> [0:10]")) + g1 = state1.add_access("g") + state1.add_nedge(t1, g1, dace.Memlet("t[0:10] -> [0:10]")) + + # return sdfg + + state1.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("g[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={g1}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_non_eligible_sdfg_with_branches(): + """Creates an SDFG with two different definitions of `T`.""" + sdfg = dace.SDFG(util.unique_name("non_eligible_sdfg_with_branches_sdfg")) + + # This is the `G` array, it is also used as output. + sdfg.add_array("a", shape=(10,), dtype=dace.float64, transient=False) + # This is the (possible) `T` array. + sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) + + # This is an additional array that serves as input. In one case it defines `t`. + sdfg.add_array("b", shape=(10,), dtype=dace.float64, transient=False) + # This is the condition on which we switch. + sdfg.add_scalar("cond", dtype=dace.bool, transient=False) + + # This is the init state. + state1 = sdfg.add_state(is_start_block=True) + + # This is the state where `T` is not defined in terms of `G`. + stateT = sdfg.add_state(is_start_block=False) + sdfg.add_edge(state1, stateT, dace.InterstateEdge(condition="cond == True")) + stateT.add_nedge( + stateT.add_access("b"), stateT.add_access("t"), dace.Memlet("b[0:10] -> [0:10]") + ) + + # This is the state where `T` is defined in terms of `G`. + stateF = sdfg.add_state(is_start_block=False) + sdfg.add_edge(state1, stateF, dace.InterstateEdge(condition="cond != True")) + stateF.add_nedge( + stateF.add_access("a"), stateF.add_access("t"), dace.Memlet("a[0:10] -> [0:10]") + ) + + # Now the write back state, where `T` is written back into `G`. + stateWB = sdfg.add_state(is_start_block=False) + stateWB.add_nedge( + stateWB.add_access("t"), stateWB.add_access("a"), dace.Memlet("t[0:10] -> [0:10]") + ) + + sdfg.add_edge(stateF, stateWB, dace.InterstateEdge()) + sdfg.add_edge(stateT, stateWB, dace.InterstateEdge()) + + sdfg.validate() + return sdfg + + +def test_not_apply_because_of_write_to_g(): + sdfg = _make_not_apply_because_of_write_to_g_sdfg() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_eligible_sdfg_1(): + sdfg = _make_eligible_sdfg_1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t"} + assert nb_access_nodes_initially == len(access_nodes_after) + 3 + assert not any(an.data == "t" for an in access_nodes_after) + + +def test_multiple_temporaries(): + sdfg = _make_multiple_temporaries_sdfg1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t1", "t2"} + assert not any(an.data.startswith("t") for an in access_nodes_after) + assert nb_access_nodes_initially == len(access_nodes_after) + 5 + + +def test_multiple_temporaries_2(): + sdfg = _make_multiple_temporaries_sdfg2() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t1", "t2"} + assert not any(an.data.startswith("t") for an in access_nodes_after) + assert nb_access_nodes_initially == len(access_nodes_after) + 6 + + +def test_multiple_temporaries_keep_one_1(): + sdfg = _make_multiple_temporaries_sdfg_keep_one_1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + # NOTE: The transformation will not only remove the `(t2) -> (a)` write in the + # second block, but also the `(t1) -> (a)` write, this is because it was + # concluded that this was a noops write. This might be a bit unintuitive + # considering that `t1` is used in the third state. However, this is why the + # `(a) -> (t1)` write in the first state is maintained. + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + start_block_nodes = util.count_nodes(sdfg.start_block, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t2"} + assert not any(an.data == "t2" for an in access_nodes_after) + assert sum(an.data == "t1" for an in access_nodes_after) == 2 + assert nb_access_nodes_initially == len(access_nodes_after) + 3 + assert len(start_block_nodes) == 2 + assert {nb.data for nb in start_block_nodes} == {"a", "t1"} + + +def test_multiple_temporaries_keep_one_2(): + sdfg = _make_multiple_temporaries_sdfg_keep_one_2() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t2"} + assert not any(an.data == "t2" for an in access_nodes_after) + assert sum(an.data == "t1" for an in access_nodes_after) == 2 + assert nb_access_nodes_initially == len(access_nodes_after) + 4 + + +def test_pseudo_temporaries(): + sdfg = _make_non_eligible_because_of_pseudo_temporary() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_single_state(): + sdfg = _make_wb_single_state_sdfg() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_branches(): + sdfg = _make_non_eligible_sdfg_with_branches() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py similarity index 93% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py index 1d98fef8c4..2264102182 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py @@ -17,8 +17,6 @@ from . import util -import dace - def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: """Generates an SDFG that contains the self copying pattern.""" @@ -51,7 +49,7 @@ def test_global_self_copy_elimination_only_pattern(): assert state.number_of_edges() == 2 count = sdfg.apply_transformations_repeated( - gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True ) assert count != 0 @@ -90,7 +88,7 @@ def test_global_self_copy_elimination_g_downstream(): assert state2.number_of_nodes() == 5 count = sdfg.apply_transformations_repeated( - gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True ) assert count != 0 @@ -132,7 +130,7 @@ def test_global_self_copy_elimination_tmp_downstream(): assert state2.number_of_nodes() == 5 count = sdfg.apply_transformations_repeated( - gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True ) assert count != 0 From 7bdfaa175364165315df6a80aab71a63ae610b54 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 3 Mar 2025 09:11:27 +0100 Subject: [PATCH 163/178] feat[next]: ITIR type inference: store param types in `itir.Lambda` (#1868) --- .../next/iterator/type_system/inference.py | 8 +++++++ .../iterator_tests/test_type_inference.py | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fe450625db..d6faefc372 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -37,6 +37,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert type_info.is_compatible_type( node.type, type_ ), "Node already has a type which differs." + # Also populate the type of the parameters of a lambda. That way the one can access the type + # of a parameter by a lookup in the symbol table. As long as `_set_node_type` is used + # exclusively, the information stays consistent with the types stored in the `FunctionType` + # of the lambda itself. + if isinstance(node, itir.Lambda): + assert isinstance(type_, ts.FunctionType) + for param, param_type in zip(node.params, type_.pos_only_args): + _set_node_type(param, param_type) node.type = type_ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..c13cb1d119 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -238,6 +238,26 @@ def test_adhoc_polymorphism(): assert result.type == ts.TupleType(types=[bool_type, int_type]) +def test_binary_lambda(): + func = im.lambda_("a", "b")(im.make_tuple("a", "b")) + testee = im.call(func)(im.ref("a_", bool_type), im.ref("b_", int_type)) + + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + + expected_type = ts.TupleType(types=[bool_type, int_type]) + assert result.type == expected_type + assert result.fun.params[0].type == bool_type + assert result.fun.params[1].type == int_type + assert result.fun.type == ts.FunctionType( + pos_only_args=[bool_type, int_type], + pos_or_kw_args={}, + kw_only_args={}, + returns=expected_type, + ) + + def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) result = itir_type_inference.infer(testee, offset_provider_type={}) @@ -245,6 +265,7 @@ def test_aliased_function(): assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) + assert result.args[0].params[0].type == int_type assert result.type == int_type From efb6373c67bc4b1d3f62111d540fe28fd2d007a6 Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 3 Mar 2025 11:47:27 +0100 Subject: [PATCH 164/178] refactor[next]: Global ordering relation of dimensions (#1847) This PR introduces a global ordering relation of dimensions, replacing the previous mechanism in promote_dims. The ordering relation is: sorting first by `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then lexicographically by `Dimension.value`. Reason: An as_fieldop call as emitted by the frontend has no domain, and inferring the type of the domain was not possible, since no global ordering relation of dimensions existed, e.g. for an as_fieldop operating on a `Vertex` and a `K` field it was unclear if the dimensions of the output were `(Vertex, K)` or (`K, Vertex)`, which lead to many negative consequences in other places, that will be tackled in [PR 1853](https://github.com/GridTools/gt4py/pull/1853) and following ones. --- src/gt4py/next/common.py | 103 +++++++----------- src/gt4py/next/type_system/type_info.py | 10 +- .../next/type_system/type_specifications.py | 6 + tests/next_tests/integration_tests/cases.py | 1 + .../ffront_tests/test_gt4py_builtins.py | 9 +- .../ffront_tests/test_foast_to_gtir.py | 4 +- .../ffront_tests/test_type_deduction.py | 14 +-- .../iterator_tests/test_type_inference.py | 5 +- .../binding_tests/test_cpp_interface.py | 8 +- tests/next_tests/unit_tests/test_common.py | 48 ++++---- .../next_tests/unit_tests/test_type_system.py | 5 +- 11 files changed, 92 insertions(+), 121 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index e5b393f1ae..f615833045 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -69,6 +69,9 @@ def __str__(self) -> str: return self.value +_DIM_KIND_ORDER = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} + + def dimension_to_implicit_offset(dim: str) -> str: """ Return name of offset implicitly defined by a dimension. @@ -1123,84 +1126,56 @@ class GridType(StrEnum): UNSTRUCTURED = "unstructured" +def _ordered_dims(dims: list[Dimension] | set[Dimension]) -> list[Dimension]: + return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value)) + + +def check_dims(dims: list[Dimension]) -> None: + if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + + if dims != _ordered_dims(dims): + raise ValueError( + f"Dimensions '{', '.join(map(str, dims))}' are not ordered correctly, expected '{', '.join(map(str, _ordered_dims(dims)))}'." + ) + + def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: """ - Find a unique ordering of multiple (individually ordered) lists of dimensions. + Find an ordering of multiple lists of dimensions. - The resulting list of dimensions contains all dimensions of the arguments - in the order they originally appear. If no unique order exists or a - contradicting order is found an exception is raised. - - A modified version (ensuring uniqueness of the order) of - `Kahn's algorithm `_ - is used to topologically sort the arguments. + The resulting list contains all unique dimensions from the input lists, + sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then + lexicographically by `Dimension.value`. Examples: >>> from gt4py.next.common import Dimension - >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) - >>> promote_dims([I, J], [I, J, K]) == [I, J, K] + >>> I = Dimension("I", DimensionKind.HORIZONTAL) + >>> J = Dimension("J", DimensionKind.HORIZONTAL) + >>> K = Dimension("K", DimensionKind.VERTICAL) + >>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL) + >>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL) + >>> promote_dims([J, K], [I, K]) == [I, J, K] True - - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + >>> promote_dims([K, J], [I, K]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + ValueError: Dimensions 'K[vertical], J[horizontal]' are not ordered correctly, expected 'J[horizontal], K[vertical]'. + >>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V] + True + >>> promote_dims([I, E2C], [K, E2V]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + ValueError: There are more than one dimension with DimensionKind 'LOCAL'. """ - # build a graph with the vertices being dimensions and edges representing - # the order between two dimensions. The graph is encoded as a dictionary - # mapping dimensions to their predecessors, i.e. a dictionary containing - # adjacency lists. Since graphlib.TopologicalSorter uses predecessors - # (contrary to successors) we also use this directionality here. - graph: dict[Dimension, set[Dimension]] = {} + for dims in dims_list: - if len(dims) == 0: - continue - # create a vertex for each dimension - for dim in dims: - graph.setdefault(dim, set()) - # add edges - predecessor = dims[0] - for dim in dims[1:]: - graph[dim].add(predecessor) - predecessor = dim - - # modified version of Kahn's algorithm - topologically_sorted_list: list[Dimension] = [] - - # compute in-degree for each vertex - in_degree = {v: 0 for v in graph.keys()} - for v1 in graph: - for v2 in graph[v1]: - in_degree[v2] += 1 - - # process vertices with in-degree == 0 - # TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list - while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]: - if len(zero_in_degree_vertex_list) != 1: - raise ValueError( - f"Dimensions can not be promoted. Could not determine " - f"order of the following dimensions: " - f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}." - ) - v = zero_in_degree_vertex_list[0] - del in_degree[v] - topologically_sorted_list.insert(0, v) - # update in-degree - for predecessor in graph[v]: - in_degree[predecessor] -= 1 - - if len(in_degree.items()) > 0: - raise ValueError( - f"Dimensions can not be promoted. The following dimensions " - f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}." - ) + check_dims(list(dims)) + unique_dims = {dim for dims in dims_list for dim in dims} - return topologically_sorted_list + promoted_dims = _ordered_dims(unique_dims) + check_dims(promoted_dims) + return promoted_dims class FieldBuiltinFuncRegistry: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 27dd2cf02c..bbaaa82728 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -567,12 +567,11 @@ def promote( >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True - >>> promote( + >>> promoted: ts.FieldType = promote( ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) - ... ) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ... ) + >>> promoted.dims == [I, J, K] and promoted.dtype == dtype + True """ if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): @@ -642,6 +641,7 @@ def return_type_field( new_dims.append(d) else: new_dims.extend(target_dims) + new_dims = common._ordered_dims(new_dims) # e.g. `Vertex, V2E, K` -> `Vertex, K, V2E` return ts.FieldType(dims=new_dims, dtype=field_type.dtype) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 2fbd039d16..5b46f9dd0d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,12 @@ def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" + @eve_datamodels.validator("dims") + def _dims_validator( + self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension] + ) -> None: + common.check_dims(dims) + class TupleType(DataType): # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 89ad556476..6e8ff1b3f6 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -60,6 +60,7 @@ # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ab1c625fef..d7fe252cb4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -95,16 +95,9 @@ def reduction_ek_field( return neighbor_sum(edge_f(V2E), axis=V2EDim) -@gtx.field_operator -def reduction_ke_field( - edge_f: common.Field[[KDim, Edge], np.int32], -) -> common.Field[[KDim, Vertex], np.int32]: - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - @pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( - "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ + "fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case_3d, fop): v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index c0d762efc8..776cd4e1a9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -901,7 +901,7 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: def test_broadcast(): def foo(inp: gtx.Field[[TDim], float64]): - return broadcast(inp, (UDim, TDim)) + return broadcast(inp, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) @@ -912,7 +912,7 @@ def foo(inp: gtx.Field[[TDim], float64]): def test_scalar_broadcast(): def foo(): - return broadcast(1, (UDim, TDim)) + return broadcast(1, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 254772fd8a..f9393bd99c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -81,20 +81,18 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): def test_binop_nonmatching_dims(): - """Binary operations can only work when both fields have the same dimensions.""" + """Dimension promotion is applied before Binary operations, i.e., they can also work on two fields that don't have the same dimensions.""" X = Dimension("X") Y = Dimension("Y") def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b - with pytest.raises( - errors.DSLError, - match=( - r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." - ), - ): - _ = FieldOperatorParser.apply_to_function(nonmatching) + parsed = FieldOperatorParser.apply_to_function(nonmatching) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[X, Y], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) def test_bitopping_float(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index c13cb1d119..6e2f941095 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -43,7 +43,7 @@ unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - +from next_tests.integration_tests.cases import IField, JField bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) @@ -52,8 +52,11 @@ int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) +float_j_field = ts.FieldType(dims=[JDim], dtype=float64_type) +float_ij_field = ts.FieldType(dims=[IDim, JDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type) +float_edge_field = ts.FieldType(dims=[Edge], dtype=float64_type) float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) it_on_v_of_e_type = it_ts.IteratorType( diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index 51b6bf512b..a25732649a 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -60,14 +60,14 @@ def function_buffer_example(): interface.Parameter( name="a_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ), interface.Parameter( name="b_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.INT64) + dims=[gtx.Dimension("bar")], dtype=ts.ScalarType(ts.ScalarKind.INT64) ), ), ], @@ -111,11 +111,11 @@ def function_tuple_example(): type_=ts.TupleType( types=[ ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ] diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8f46fc7ce1..09ca44aaac 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -10,7 +10,9 @@ from typing import Optional, Pattern import pytest +import re +from gt4py import next as gtx import gt4py.next.common as common from gt4py.next.common import ( Dimension, @@ -25,7 +27,11 @@ unit_range, ) - +C2E = Dimension("C2E", kind=DimensionKind.LOCAL) +V2E = Dimension("V2E", kind=DimensionKind.LOCAL) +E2V = Dimension("E2V", kind=DimensionKind.LOCAL) +E2C = Dimension("E2C", kind=DimensionKind.LOCAL) +E2C2V = Dimension("E2C2V", kind=DimensionKind.LOCAL) ECDim = Dimension("ECDim") IDim = Dimension("IDim") JDim = Dimension("JDim") @@ -324,16 +330,6 @@ def test_domain_intersection_different_dimensions(a_domain, second_domain, expec assert result_domain == expected -def test_domain_intersection_reversed_dimensions(a_domain): - domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) - - with pytest.raises( - ValueError, - match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", - ): - a_domain & domain2 - - @pytest.mark.parametrize( "index, expected", [ @@ -571,27 +567,29 @@ def dimension_promotion_cases() -> ( ): raw_list = [ # list of list of dimensions, expected result, expected error message - ([["I", "J"], ["I"]], ["I", "J"], None), - ([["I", "J"], ["J"]], ["I", "J"], None), - ([["I", "J"], ["J", "K"]], ["I", "J", "K"], None), + ([[IDim, JDim], [IDim]], [IDim, JDim], None), + ([[JDim], [IDim, JDim]], [IDim, JDim], None), + ([[JDim, KDim], [IDim, JDim]], [IDim, JDim, KDim], None), ( - [["I", "J"], ["J", "I"]], + [[IDim, JDim], [JDim, IDim]], None, - r"The following dimensions appear in contradicting order: I, J.", + "Dimensions 'JDim[horizontal], IDim[horizontal]' are not ordered correctly, expected 'IDim[horizontal], JDim[horizontal]'.", ), + ([[JDim, KDim], [IDim, KDim]], [IDim, JDim, KDim], None), ( - [["I", "K"], ["J", "K"]], + [[KDim, JDim], [IDim, KDim]], None, - r"Could not determine order of the following dimensions: I, J", + "Dimensions 'KDim[vertical], JDim[horizontal]' are not ordered correctly, expected 'JDim[horizontal], KDim[vertical]'.", ), + ( + [[JDim, V2E], [IDim, KDim, E2C2V]], + None, + "There are more than one dimension with DimensionKind 'LOCAL'.", + ), + ([[JDim, V2E], [IDim, KDim]], [IDim, JDim, KDim, V2E], None), ] - # transform dimension names into Dimension objects return [ - ( - [[Dimension(el) for el in arg] for arg in args], - [Dimension(el) for el in result] if result else result, - msg, - ) + ([[el for el in arg] for arg in args], [el for el in result] if result else result, msg) for args, result, msg in raw_list ] @@ -608,7 +606,7 @@ def test_dimension_promotion( with pytest.raises(Exception) as exc_info: promote_dims(*dim_list) - assert exc_info.match(expected_error_msg) + assert exc_info.match(re.escape(expected_error_msg)) class TestCartesianConnectivity: diff --git a/tests/next_tests/unit_tests/test_type_system.py b/tests/next_tests/unit_tests/test_type_system.py index 99758d6f14..69ff54b711 100644 --- a/tests/next_tests/unit_tests/test_type_system.py +++ b/tests/next_tests/unit_tests/test_type_system.py @@ -305,10 +305,7 @@ def callable_type_info_cases(): ts.FieldType(dims=[KDim], dtype=int_type), ], {}, - [ - r"Dimensions can not be promoted. Could not determine order of the " - r"following dimensions: J, K." - ], + [], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), ( From b6a316285ad5825b9a9f2be5b3a8df36fa2450e1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 4 Mar 2025 07:39:33 +0100 Subject: [PATCH 165/178] refactor: stick to supported device types (#1879) ## Description Following the [YAGNI principle](https://github.com/GridTools/gt4py/blob/main/CODING_GUIDELINES.md#software-design), stick to the device types that are actually supported in the codebase. We can add support for other devices later. @havogt as discussed the other day, let's have a discussion about whether or not we need all these device types defined. I know that @FlorianDeconinck would like to work towards over protection and throwing errors rather sooner than later for unsupported things (very much in general). This is part of a clean-up series tracked in issue https://github.com/GridTools/gt4py/issues/1880. ## Requirements - [ ] All fixes and/or new features come with corresponding tests. N/A - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/_core/definitions.py | 26 +++++++------------------- src/gt4py/storage/cartesian/utils.py | 2 +- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 8f62788b8f..ba273c75d9 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -373,37 +373,25 @@ class DeviceType(enum.IntEnum): CPU = 1 CUDA = 2 - CPU_PINNED = 3 - OPENCL = 4 - VULKAN = 7 - METAL = 8 - VPI = 9 + # CPU_PINNED = 3 # noqa: ERA001 + # OPENCL = 4 # noqa: ERA001 + # VULKAN = 7 # noqa: ERA001 + # METAL = 8 # noqa: ERA001 + # VPI = 9 # noqa: ERA001 ROCM = 10 - CUDA_MANAGED = 13 - ONE_API = 14 + # CUDA_MANAGED = 13 # noqa: ERA001 + # ONE_API = 14 # noqa: ERA001 CPUDeviceTyping: TypeAlias = Literal[DeviceType.CPU] CUDADeviceTyping: TypeAlias = Literal[DeviceType.CUDA] -CPUPinnedDeviceTyping: TypeAlias = Literal[DeviceType.CPU_PINNED] -OpenCLDeviceTyping: TypeAlias = Literal[DeviceType.OPENCL] -VulkanDeviceTyping: TypeAlias = Literal[DeviceType.VULKAN] -MetalDeviceTyping: TypeAlias = Literal[DeviceType.METAL] -VPIDeviceTyping: TypeAlias = Literal[DeviceType.VPI] ROCMDeviceTyping: TypeAlias = Literal[DeviceType.ROCM] -CUDAManagedDeviceTyping: TypeAlias = Literal[DeviceType.CUDA_MANAGED] -OneApiDeviceTyping: TypeAlias = Literal[DeviceType.ONE_API] DeviceTypeT = TypeVar( "DeviceTypeT", CPUDeviceTyping, CUDADeviceTyping, - CPUPinnedDeviceTyping, - OpenCLDeviceTyping, - VulkanDeviceTyping, - MetalDeviceTyping, - VPIDeviceTyping, ROCMDeviceTyping, ) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index bd89c85052..2e1bfb69b5 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -199,7 +199,7 @@ def asarray( elif not device: if hasattr(array, "__dlpack_device__"): kind, _ = array.__dlpack_device__() - if kind in [core_defs.DeviceType.CPU, core_defs.DeviceType.CPU_PINNED]: + if kind in [core_defs.DeviceType.CPU]: xp = np elif kind in [ core_defs.DeviceType.CUDA, From 28eba10c44650cb8d30f7b093fed50e3dfd0b336 Mon Sep 17 00:00:00 2001 From: Tobias Wicky-Pfund Date: Tue, 4 Mar 2025 14:51:34 +0100 Subject: [PATCH 166/178] style[cartesian]: Cleanup codegen tests (#1899) ## Description This PR cleans up all the warnings that would be present in the test_code_generation. --- .../test_code_generation.py | 207 +++++++++--------- 1 file changed, 106 insertions(+), 101 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 4e0fa8903c..c2b82e4bac 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -27,17 +27,16 @@ ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ( - ALL_BACKENDS, - CPU_BACKENDS, - get_array_library, -) +from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, ) +rng = np.random.default_rng(1337) + + @pytest.mark.parametrize("name", stencil_definitions) @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_generation(name, backend): @@ -63,15 +62,15 @@ def test_generation(name, backend): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_lazy_stencil(backend): @gtscript.lazy_stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float64], field_b: gtscript.Field[np.float64]): + def definition(field_a: Field[np.float64], field_b: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - field_a = field_b + field_a[0, 0, 0] = field_b @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_temporary_field_declared_in_if(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float64]): + def definition(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): if field_a < 0: field_b = -field_a @@ -83,9 +82,9 @@ def definition(field_a: gtscript.Field[np.float64]): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_stage_without_effect(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float64]): + def definition(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - field_c = 0.0 + field_c = 0.0 # noqa: F841 def test_ignore_np_errstate(): @@ -95,7 +94,7 @@ def setup_and_run(backend, **kwargs): ) @gtscript.stencil(backend=backend, **kwargs) - def divide_by_zero(field_a: gtscript.Field[np.float64]): + def divide_by_zero(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): field_a = 1.0 / field_a @@ -110,16 +109,16 @@ def divide_by_zero(field_a: gtscript.Field[np.float64]): @pytest.mark.parametrize("backend", CPU_BACKENDS) def test_stencil_without_effect(backend): - def definition1(field_in: gtscript.Field[np.float64]): + def definition1(field_in: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - tmp = 0.0 + tmp = 0.0 # noqa: F841 - def definition2(f_in: gtscript.Field[np.float64]): - from __externals__ import flag + def definition2(f_in: Field[np.float64]): # type: ignore + from __externals__ import flag # type: ignore with computation(PARALLEL), interval(...): if __INLINED(flag): - B = f_in + B = f_in # noqa: F841 stencil1 = gtscript.stencil(backend, definition1) stencil2 = gtscript.stencil(backend, definition2, externals={"flag": False}) @@ -146,7 +145,7 @@ def test_stage_merger_induced_interval_block_reordering(backend): ) @gtscript.stencil(backend=backend) - def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.float64]): + def stencil(field_in: Field[np.float64], field_out: Field[np.float64]): # type: ignore with computation(BACKWARD): with interval(-2, -1): # block 1 field_out = field_in @@ -156,7 +155,7 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f with interval(-1, None): # block 3 field_out = 2 * field_in with interval(0, -1): # block 4 - field_out = 3 * field_in + field_out[0, 0, 0] = 3 * field_in stencil(field_in, field_out) @@ -168,9 +167,9 @@ def stencil(field_in: gtscript.Field[np.float64], field_out: gtscript.Field[np.f def test_lower_dimensional_inputs(backend): @gtscript.stencil(backend=backend) def stencil( - field_3d: gtscript.Field[gtscript.IJK, np.float64], - field_2d: gtscript.Field[gtscript.IJ, np.float64], - field_1d: gtscript.Field[gtscript.K, np.float64], + field_3d: Field[gtscript.IJK, np.float64], # type: ignore + field_2d: Field[gtscript.IJ, np.float64], # type: ignore + field_1d: Field[gtscript.K, np.float64], # type: ignore ): with computation(PARALLEL): with interval(0, -1): @@ -182,7 +181,7 @@ def stencil( with interval(0, 1): field_3d = tmp[1, 0, 0] + field_1d[1] with interval(1, None): - field_3d = tmp[-1, 0, 0] + field_3d[0, 0, 0] = tmp[-1, 0, 0] full_shape = (6, 6, 6) aligned_index = (1, 1, 0) @@ -223,17 +222,17 @@ def stencil( def test_lower_dimensional_masked(backend): @gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float64], - inp: gtscript.Field[gtscript.IJ, np.float64], - outp: gtscript.Field[gtscript.IJK, np.float64], + cond: Field[gtscript.IJK, np.float64], # type: ignore + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(PARALLEL), interval(...): if cond > 0.0: - outp = inp + outp[0, 0, 0] = inp - inp = np.random.randn(10, 10) - outp = np.random.randn(10, 10, 10) - cond = np.random.randn(10, 10, 10) + inp = rng.standard_normal((10, 10)) + outp = rng.standard_normal((10, 10, 10)) + cond = rng.standard_normal((10, 10, 10)) inp_f = gt_storage.from_array(inp, aligned_index=(0, 0), backend=backend) outp_f = gt_storage.from_array(outp, aligned_index=(0, 0, 0), backend=backend) @@ -254,17 +253,17 @@ def copy_2to3( def test_lower_dimensional_masked_2dcond(backend): @gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float64], - inp: gtscript.Field[gtscript.IJ, np.float64], - outp: gtscript.Field[gtscript.IJK, np.float64], + cond: Field[gtscript.IJK, np.float64], # type: ignore + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(FORWARD), interval(...): if cond > 0.0: - outp = inp + outp[0, 0, 0] = inp - inp = np.random.randn(10, 10) - outp = np.random.randn(10, 10, 10) - cond = np.random.randn(10, 10, 10) + inp = rng.standard_normal((10, 10)) + outp = rng.standard_normal((10, 10, 10)) + cond = rng.standard_normal((10, 10, 10)) inp_f = gt_storage.from_array(inp, aligned_index=(0, 0), backend=backend) outp_f = gt_storage.from_array(outp, aligned_index=(0, 0, 0), backend=backend) @@ -285,15 +284,17 @@ def copy_2to3( def test_lower_dimensional_inputs_2d_to_3d_forward(backend): @gtscript.stencil(backend=backend) def copy_2to3( - inp: gtscript.Field[gtscript.IJ, np.float64], - outp: gtscript.Field[gtscript.IJK, np.float64], + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(FORWARD), interval(...): outp[0, 0, 0] = inp - inp_f = gt_storage.from_array(np.random.randn(10, 10), aligned_index=(0, 0), backend=backend) + inp_f = gt_storage.from_array( + rng.standard_normal((10, 10)), aligned_index=(0, 0), backend=backend + ) outp_f = gt_storage.from_array( - np.random.randn(10, 10, 10), aligned_index=(0, 0, 0), backend=backend + rng.standard_normal((10, 10, 10)), aligned_index=(0, 0, 0), backend=backend ) copy_2to3(inp_f, outp_f) inp_f = storage_utils.cpu_copy(inp_f) @@ -308,12 +309,12 @@ def test_higher_dimensional_fields(backend): @gtscript.stencil(backend=backend) def stencil( - field: gtscript.Field[np.float64], - vec_field: gtscript.Field[FLOAT64_VEC2], - mat_field: gtscript.Field[FLOAT64_MAT22], + field: Field[np.float64], # type: ignore + vec_field: Field[FLOAT64_VEC2], # type: ignore + mat_field: Field[FLOAT64_MAT22], # type: ignore ): with computation(PARALLEL), interval(...): - tmp = vec_field[0, 0, 0][0] + vec_field[0, 0, 0][1] + tmp = vec_field[0, 0, 0][0] + vec_field[0, 0, 0][1] # noqa: F841 with computation(FORWARD): with interval(0, 1): @@ -360,45 +361,45 @@ def stencil( def test_input_order(backend): @gtscript.stencil(backend=backend) def stencil( - in_field: gtscript.Field[np.float64], + in_field: Field[np.float64], # type: ignore parameter: np.float64, - out_field: gtscript.Field[np.float64], + out_field: Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): - out_field = in_field * parameter + out_field[0, 0, 0] = in_field * parameter @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets(backend): @gtscript.stencil(backend=backend) def stencil_ij( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], - index_field: gtscript.Field[gtscript.IJ, int], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore + index_field: Field[gtscript.IJ, int], # type: ignore ): with computation(FORWARD), interval(...): - out_field = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] + out_field[0, 0, 0] = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] index_field = index_field + 1 @gtscript.stencil(backend=backend) def stencil_ijk( - in_field: gtscript.Field[np.float64], - out_field: gtscript.Field[np.float64], - index_field: gtscript.Field[int], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore + index_field: Field[int], # type: ignore ): with computation(PARALLEL), interval(...): - out_field = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] + out_field[0, 0, 0] = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets_and_while_loop(backend): @gtscript.stencil(backend=backend) def stencil( - pe1: gtscript.Field[np.float64], - pe2: gtscript.Field[np.float64], - qin: gtscript.Field[np.float64], - qout: gtscript.Field[np.float64], - lev: gtscript.Field[gtscript.IJ, np.int_], + pe1: Field[np.float64], # type: ignore + pe2: Field[np.float64], # type: ignore + qin: Field[np.float64], # type: ignore + qout: Field[np.float64], # type: ignore + lev: Field[gtscript.IJ, np.int_], # type: ignore ): with computation(FORWARD), interval(0, -1): if pe2[0, 0, 1] <= pe1[0, 0, lev]: @@ -408,13 +409,13 @@ def stencil( while pe1[0, 0, lev + 1] < pe2[0, 0, 1]: qsum += qin[0, 0, lev] / (pe2[0, 0, 1] - pe1[0, 0, lev]) lev = lev + 1 - qout = qsum / (pe2[0, 0, 1] - pe2) + qout[0, 0, 0] = qsum / (pe2[0, 0, 1] - pe2) @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_nested_while_loop(backend): @gtscript.stencil(backend=backend) - def stencil(field_a: gtscript.Field[np.float64], field_b: gtscript.Field[np.int_]): + def stencil(field_a: Field[np.float64], field_b: Field[np.int_]): # type: ignore with computation(PARALLEL), interval(...): while field_a < 1: add = 0 @@ -426,13 +427,13 @@ def stencil(field_a: gtscript.Field[np.float64], field_b: gtscript.Field[np.int_ @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): @gtscript.stencil(backend) - def stencil(outp: gtscript.Field[np.float64]): + def stencil(outp: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): cond = True if cond[0, -1, 0] or cond[0, 0, 0]: outp = 1.0 else: - outp = 0.0 + outp[0, 0, 0] = 0.0 outp = gt_storage.zeros( shape=(10, 10, 10), backend=backend, aligned_index=(0, 0, 0), dtype=float @@ -449,8 +450,8 @@ def test_write_data_dim_indirect_addressing(backend): INT32_VEC2 = (np.int32, (2,)) def stencil( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, INT32_VEC2], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, INT32_VEC2], # type: ignore index: int, ): with computation(PARALLEL), interval(...): @@ -474,12 +475,12 @@ def test_read_data_dim_indirect_addressing(backend): INT32_VEC2 = (np.int32, (2,)) def stencil( - input_field: gtscript.Field[gtscript.IJK, INT32_VEC2], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, INT32_VEC2], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore index: int, ): with computation(PARALLEL), interval(...): - output_field = input_field[0, 0, 0][index] + output_field[0, 0, 0] = input_field[0, 0, 0][index] aligned_index = (0, 0, 0) full_shape = (1, 1, 2) @@ -499,11 +500,11 @@ class TestNegativeOrigin: def test_negative_origin_i(self, backend): @gtscript.stencil(backend=backend) def stencil_i( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore ): with computation(PARALLEL), interval(...): - output_field = input_field[1, 0, 0] + output_field[0, 0, 0] = input_field[1, 0, 0] input_field = gt_storage.ones( backend=backend, aligned_index=(0, 0, 0), shape=(1, 1, 1), dtype=np.int32 @@ -518,11 +519,11 @@ def stencil_i( def test_negative_origin_k(self, backend): @gtscript.stencil(backend=backend) def stencil_k( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore ): with computation(PARALLEL), interval(...): - output_field = input_field[0, 0, 1] + output_field[0, 0, 0] = input_field[0, 0, 1] input_field = gt_storage.ones( backend=backend, aligned_index=(0, 0, 0), shape=(1, 1, 1), dtype=np.int32 @@ -538,9 +539,9 @@ def stencil_k( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_origin_k_fields(backend): @gtscript.stencil(backend=backend, rebuild=True) - def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): + def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): # type: ignore with computation(PARALLEL), interval(...): - outp = inp + outp[0, 0, 0] = inp origin = {"outp": (0, 0, 1), "inp": (2,)} domain = (2, 2, 8) @@ -566,11 +567,11 @@ def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_pruned_args_match(backend): @gtscript.stencil(backend=backend) - def test(out: Field[np.float64], inp: Field[np.float64]): + def test(out: Field[np.float64], inp: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): out = 0.0 with horizontal(region[I[0] - 1, J[0] - 1]): - out = inp + out[0, 0, 0] = inp inp = gt_storage.zeros( backend=backend, aligned_index=(0, 0, 0), shape=(2, 2, 2), dtype=np.float64 @@ -600,7 +601,7 @@ def test_K_offset_write(backend): # A is untouched # B is written in K+1 and should have K_values, except for the first element (FORWARD) @gtscript.stencil(backend=backend) - def simple(A: Field[np.float64], B: Field[np.float64]): + def simple(A: Field[np.float64], B: Field[np.float64]): # type: ignore with computation(FORWARD), interval(...): B[0, 0, 1] = A @@ -619,7 +620,7 @@ def simple(A: Field[np.float64], B: Field[np.float64]): # means while A is update B will have non-updated values of A # Because of the interval, value of B[0] is 0 @gtscript.stencil(backend=backend) - def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore with computation(FORWARD), interval(1, None): A[0, 0, -1] = scalar B[0, 0, 0] = A @@ -640,7 +641,7 @@ def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # Order of operations: BACKWARD with negative offset # means A is update B will get the updated values of A @gtscript.stencil(backend=backend) - def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore with computation(BACKWARD), interval(1, None): A[0, 0, -1] = scalar B[0, 0, 0] = A @@ -668,7 +669,7 @@ def test_K_offset_write_conditional(backend): K_values = arraylib.arange(start=40, stop=44) @gtscript.stencil(backend=backend) - def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore with computation(BACKWARD), interval(1, -1): if A > 0 and B > 0: A[0, 0, -1] = scalar @@ -732,9 +733,9 @@ def test_direct_datadims_index(backend): F64_VEC4 = (np.float64, (2, 2, 2, 2)) @gtscript.stencil(backend=backend) - def test(out: Field[np.float64], inp: GlobalTable[F64_VEC4]): + def test(out: Field[np.float64], inp: GlobalTable[F64_VEC4]): # type: ignore with computation(PARALLEL), interval(...): - out = inp.A[1, 0, 1, 0] + out[0, 0, 0] = inp.A[1, 0, 1, 0] inp = gt_storage.ones(backend=backend, shape=(2, 2, 2, 2), dtype=np.float64) inp[1, 0, 1, 0] = 42 @@ -751,8 +752,8 @@ def add_42(v): @gtscript.stencil(backend=backend) def test( - in_field: Field[np.float64], - out_field: Field[np.float64], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): count = 1 @@ -770,11 +771,12 @@ def test( def _xfail_dace_backends(param): if param.values[0].startswith("dace:"): - marks = param.marks + [ + marks = [ + *param.marks, pytest.mark.xfail( raises=ValueError, reason="Missing support in DaCe backends, see https://github.com/GridTools/gt4py/issues/1881.", - ) + ), ] # make a copy because otherwise we are operating in-place return pytest.param(*param.values, marks=marks) @@ -785,11 +787,14 @@ def _xfail_dace_backends(param): def test_cast_in_index(backend): @gtscript.stencil(backend) def cast_in_index( - in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] + in_field: Field[np.float64], # type: ignore + i32: np.int32, + i64: np.int64, + out_field: Field[np.float64], # type: ignore ): """Simple copy stencil with forced cast in index calculation.""" with computation(PARALLEL), interval(...): - out_field = in_field[0, 0, i32 - i64] + out_field[0, 0, 0] = in_field[0, 0, i32 - i64] @pytest.mark.parametrize("backend", ALL_BACKENDS) @@ -798,15 +803,15 @@ def test_read_after_write_stencil(backend): @gtscript.stencil(backend=backend) def lagrangian_contributions( - q: Field[np.float64], - pe1: Field[np.float64], - pe2: Field[np.float64], - q4_1: Field[np.float64], - q4_2: Field[np.float64], - q4_3: Field[np.float64], - q4_4: Field[np.float64], - dp1: Field[np.float64], - lev: gtscript.Field[gtscript.IJ, np.int64], + q: Field[np.float64], # type: ignore + pe1: Field[np.float64], # type: ignore + pe2: Field[np.float64], # type: ignore + q4_1: Field[np.float64], # type: ignore + q4_2: Field[np.float64], # type: ignore + q4_3: Field[np.float64], # type: ignore + q4_4: Field[np.float64], # type: ignore + dp1: Field[np.float64], # type: ignore + lev: Field[gtscript.IJ, np.int64], # type: ignore ): """ Args: @@ -824,7 +829,7 @@ def lagrangian_contributions( pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev] if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]: pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev] - q = ( + q[0, 0, 0] = ( q4_2[0, 0, lev] + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl) - q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl) From c610561e918a0edc3fe5d0fb411ac4eae08c0cf0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 4 Mar 2025 17:18:15 +0100 Subject: [PATCH 167/178] refactor[cartesian]: gt4py/dace bridge cleanup (#1895) ## Description In preparation for PR https://github.com/GridTools/gt4py/pull/1894, pull out some refactors and cleanups. Notable in this PR are the changes to `src/gt4py/cartesian/gtc/dace/oir_to_dace.py` - visit `stencil.vertical_loops` directly instead of calling `generic_visit` (simplification since there's nothing else to visit) - rename library nodes from `f"{sdfg_name}_computation_{id(node)}"` to `f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"`. This adds a bit more information (because `sdfg_name` is the same for all library nodes) and thus simplifies debugging workflows. Related issue: https://github.com/GEOS-ESM/NDSL/issues/53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. covered by existing test suite - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/backend/base.py | 8 ++--- src/gt4py/cartesian/backend/dace_backend.py | 3 +- src/gt4py/cartesian/gtc/dace/daceir.py | 32 +++++++++---------- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 12 +++++-- .../multi_feature_tests/test_suites.py | 6 ---- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5bab0453a9..571f86b527 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -172,9 +172,9 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]: Returns ------- Dict[str, str | Dict] of source file names / directories -> contents: - If a key's value is a string it is interpreted as a file name and the value as the - source code of that file - If a key's value is a Dict, it is interpreted as a directory name and it's + If a key's value is a string, it is interpreted as a file name and its value as the + source code of that file. + If a key's value is a Dict, it is interpreted as a directory name and its value as a nested file hierarchy to which the same rules are applied recursively. The root path is relative to the build directory. @@ -222,7 +222,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: Returns ------- - Analog to :py:meth:`generate_computation` but containing bindings source code, The + Analog to :py:meth:`generate_computation` but containing bindings source code. The dictionary contains a tree of directories with leaves being a mapping from filename to source code pairs, relative to the build directory. diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 8ca18705c9..a36a9824bd 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -31,6 +31,7 @@ ) from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir from gt4py.cartesian.gtc import common, gtir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder from gt4py.cartesian.gtc.dace.transformations import ( @@ -119,8 +120,6 @@ def _set_expansion_orders(sdfg: dace.SDFG): def _set_tile_sizes(sdfg: dace.SDFG): - import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import - for node, _ in filter( lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() ): diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 43a33fdd6d..90c0649940 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -51,11 +51,11 @@ def tile_symbol(self) -> eve.SymbolRef: return eve.SymbolRef("__tile_" + self.lower()) @staticmethod - def dims_3d() -> Generator["Axis", None, None]: + def dims_3d() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J, Axis.K] @staticmethod - def dims_horizontal() -> Generator["Axis", None, None]: + def dims_horizontal() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J] def to_idx(self) -> int: @@ -357,7 +357,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]: class GridSubset(eve.Node): - intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] + intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] def __iter__(self): for axis in Axis.dims_3d(): @@ -429,10 +429,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent): @classmethod def from_interval( cls, - interval: Union[oir.Interval, TileInterval, DomainInterval, IndexWithExtent], + interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval], axis: Axis, ): - res_interval: Union[IndexWithExtent, TileInterval, DomainInterval] + res_interval: Union[DomainInterval, IndexWithExtent, TileInterval] if isinstance(interval, (DomainInterval, oir.Interval)): res_interval = DomainInterval( start=AxisBound( @@ -441,7 +441,7 @@ def from_interval( end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), ) else: - assert isinstance(interval, (TileInterval, IndexWithExtent)) + assert isinstance(interval, (IndexWithExtent, TileInterval)) res_interval = interval return cls(intervals={axis: res_interval}) @@ -464,7 +464,7 @@ def full_domain(cls, axes=None): return GridSubset(intervals=res_subsets) def tile(self, tile_sizes: Dict[Axis, int]): - res_intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] = {} + res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {} for axis, interval in self.intervals.items(): if isinstance(interval, DomainInterval) and axis in tile_sizes: if axis == Axis.K: @@ -505,15 +505,15 @@ def union(self, other): intervals[axis] = interval1.union(interval2) else: assert ( - isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, (IndexWithExtent, DomainInterval)) + isinstance(interval2, (DomainInterval, TileInterval)) + and isinstance(interval1, (DomainInterval, IndexWithExtent)) ) or ( - isinstance(interval1, (TileInterval, DomainInterval)) + isinstance(interval1, (DomainInterval, TileInterval)) and isinstance(interval2, IndexWithExtent) ) intervals[axis] = ( interval1 - if isinstance(interval1, (TileInterval, DomainInterval)) + if isinstance(interval1, (DomainInterval, TileInterval)) else interval2 ) return GridSubset(intervals=intervals) @@ -747,7 +747,7 @@ class IndexAccess(common.FieldAccess, Expr): offset: Optional[Union[common.CartesianOffset, VariableKOffset]] -class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt): +class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt): _dtype_validation = common.assign_stmt_dtype_validation(strict=True) @@ -851,14 +851,14 @@ class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait): class DomainMap(ComputationNode, IterationNode): index_ranges: List[Range] schedule: MapSchedule - computations: List[Union[Tasklet, DomainMap, NestedSDFG]] + computations: List[Union[DomainMap, NestedSDFG, Tasklet]] class ComputationState(IterationNode): - computations: List[Union[Tasklet, DomainMap]] + computations: List[Union[DomainMap, Tasklet]] -class DomainLoop(IterationNode, ComputationNode): +class DomainLoop(ComputationNode, IterationNode): axis: Axis index_range: Range loop_states: List[Union[ComputationState, DomainLoop]] @@ -868,7 +868,7 @@ class NestedSDFG(ComputationNode, eve.SymbolTableTrait): label: eve.Coerced[eve.SymbolRef] field_decls: List[FieldDecl] symbol_decls: List[SymbolDecl] - states: List[Union[DomainLoop, ComputationState]] + states: List[Union[ComputationState, DomainLoop]] # There are circular type references with string placeholders. These statements let datamodels resolve those. diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 9dd66bac82..bd06da7d8f 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -41,6 +41,7 @@ class SDFGContext: decls: Dict[str, oir.Decl] block_extents: Dict[int, Extent] access_infos: Dict[str, dcir.FieldAccessInfo] + loop_counter: int = 0 def __init__(self, stencil: oir.Stencil): self.sdfg = dace.SDFG(stencil.name) @@ -98,6 +99,13 @@ def _make_dace_subset(self, local_access_info, field): global_access_info, local_access_info, self.decls[field].data_dims ) + def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str: + sdfg_name = ctx.sdfg.name + counter = ctx.loop_counter + ctx.loop_counter += 1 + + return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}" + def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext): declarations = { acc.name: ctx.decls[acc.name] @@ -105,7 +113,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG if acc.name in ctx.decls } library_node = StencilComputation( - name=f"{ctx.sdfg.name}_computation_{id(node)}", + name=self._vloop_name(node, ctx), extents=ctx.block_extents, declarations=declarations, oir_node=node, @@ -174,6 +182,6 @@ def visit_Stencil(self, node: oir.Stencil): lifetime=dace.AllocationLifetime.Persistent, debuginfo=get_dace_debuginfo(decl), ) - self.generic_visit(node, ctx=ctx) + self.visit(node.vertical_loops, ctx=ctx) ctx.sdfg.validate() return ctx.sdfg diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 10d8999565..032dc3bb5e 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -import pytest from gt4py.cartesian import gtscript, testing as gt_testing from gt4py.cartesian.gtscript import ( @@ -25,7 +24,6 @@ from .stencil_definitions import optional_field, two_optional_fields -# ---- Identity stencil ---- class TestIdentity(gt_testing.StencilTestSuite): """Identity stencil.""" @@ -43,7 +41,6 @@ def validation(field_a, domain=None, origin=None): pass -# ---- Copy stencil ---- class TestCopy(gt_testing.StencilTestSuite): """Copy stencil.""" @@ -86,7 +83,6 @@ def validation(field_a, field_b, domain=None, origin=None): field_b[...] = (field_b[...] - 1.0) / 2.0 -# ---- Scale stencil ---- class TestGlobalScale(gt_testing.StencilTestSuite): """Scale stencil using a global global_name.""" @@ -108,7 +104,6 @@ def validation(field_a, domain, origin, **kwargs): field_a[...] = SCALE_FACTOR * field_a # noqa: F821 [undefined-name] -# ---- Parametric scale stencil ----- class TestParametricScale(gt_testing.StencilTestSuite): """Scale stencil using a parameter.""" @@ -128,7 +123,6 @@ def validation(field_a, *, scale, domain, origin, **kwargs): field_a[...] = scale * field_a -# --- Parametric-mix stencil ---- class TestParametricMix(gt_testing.StencilTestSuite): """Linear combination of input fields using several parameters.""" From 098d325579fbe7ad475b42d28d26d7f65c852f23 Mon Sep 17 00:00:00 2001 From: SF-N Date: Thu, 6 Mar 2025 09:59:45 +0100 Subject: [PATCH 168/178] refactor[next]: Simplify `ir_makers.domain` (#1903) The `domain` ir maker now only accepts Dimensions, not strings. This simplifies the typing in some places and is less error prone, since one can not accidentally create a domain with the wrong kind, e.g. by using `"KDim"`. Co-authored-by: Till Ehrengruber --- .../next/iterator/ir_utils/domain_utils.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 18 +++++------------- .../transforms_tests/test_domain_inference.py | 2 +- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 27900b6db6..17df4f2ec5 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -79,7 +79,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain: return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above def as_expr(self) -> itir.FunCall: - converted_ranges: dict[common.Dimension | str, tuple[itir.Expr, itir.Expr]] = { + converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { key: (value.start, value.stop) for key, value in self.ranges.items() } return im.domain(self.grid_type, converted_ranges) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 42b82ffdd0..9d77ca4686 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -402,22 +402,14 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: dict[Union[common.Dimension, str], tuple[itir.Expr, itir.Expr]], + ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ - >>> str( - ... domain( - ... common.GridType.CARTESIAN, - ... { - ... common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL): (0, 10), - ... common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL): (0, 20), - ... }, - ... ) - ... ) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) + >>> str(domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (0, 20)})) 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' - >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)})) - 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' - >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)})) + >>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)})) 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ if isinstance(grid_type, common.GridType): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 4a2a441510..86cc8a6773 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -515,7 +515,7 @@ def test_cond(offset_provider): testee = im.if_(cond, field_1, field_2) - domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}} expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)( From 629a073088c62760acabde5d39f04ae68a1cd504 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 7 Mar 2025 15:25:42 +0100 Subject: [PATCH 169/178] fix[next][dace]: Update cuda codegen for concat_where (#1906) We use `CopyToMap` in CUDA lowering for copies between arrays that do not necessarily have the same strides. This happens in case of `concat_where`, where we copy a source array into a subset of the destination array. For this reason, the option `ignore_strides` must be set to `True` (`False` by default). --- .../runners/dace/transformations/gpu_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 10ed652ec2..f4372db80f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -297,7 +297,14 @@ def _gt_expand_non_standard_memlets_sdfg( # Turn unsupported copy to a map try: dace_transformation.dataflow.CopyToMap.apply_to( - sdfg, save=False, annotate=False, a=a, b=b + sdfg, + save=False, + annotate=False, + a=a, + b=b, + options={ + "ignore_strides": True + }, # apply 'CopyToMap' even if src/dst strides are different ) except ValueError: # If transformation doesn't match, continue normally continue From 2f4db7236c7d7e5b2f2daec382966bd1605f088f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 10 Mar 2025 14:26:55 +0100 Subject: [PATCH 170/178] feat[dace][next]: Added `CopyChainRemover` (#1901) This PR adds the `CopyChainRemover` transformation. The introduction of `concat_where` introduced a new pattern, which is essentially a chain of copies. As an example imagine the case that a domain is split into 3 subdomain. The result on the first subdomain is stored in `T1`, the one of the second in `T2` and the one for the third domain in `T3`. `T1` and `T2` are then copied into `T4`, finally `T4` together with `T3` are then copied into `T5`. This transformation will remove `T1`, `T2`, `T3` and `T4` thus the results will be written into `T5` directly. There are some limitation, if we have the pattern `(A1) -> (A2)`, then we eliminate `A1` only if: - `A1` is fully read; this is to avoid some nasty adjustments of map bounds. - There can only be one connection between `A1` and `A2`. The transformation was added twice to the simplify pass, which allows us to mitigate DaCe [issue#1959](https://github.com/spcl/dace/issues/1959). --------- Co-authored-by: edopao --- .../runners/dace/transformations/__init__.py | 4 + .../redundant_array_removers.py | 491 +++++++++++++++- .../runners/dace/transformations/simplify.py | 34 ++ .../runners/dace/transformations/utils.py | 40 +- .../test_copy_chain_remover.py | 553 ++++++++++++++++++ .../test_distributed_buffer_relocator.py | 1 - 6 files changed, 1117 insertions(+), 6 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 81ecb107cb..6157704857 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -25,9 +25,11 @@ from .map_orderer import MapIterationOrder, gt_set_iteration_order from .map_promoter import SerialMapPromoter from .redundant_array_removers import ( + CopyChainRemover, MultiStateGlobalSelfCopyElimination, SingleStateGlobalSelfCopyElimination, gt_multi_state_global_self_copy_elimination, + gt_remove_copy_chain, ) from .simplify import ( GT_SIMPLIFY_DEFAULT_SKIP_SET, @@ -50,6 +52,7 @@ __all__ = [ "GT_SIMPLIFY_DEFAULT_SKIP_SET", + "CopyChainRemover", "GPUSetBlockSize", "GT4PyMapBufferElimination", "GT4PyMoveTaskletIntoMap", @@ -76,6 +79,7 @@ "gt_propagate_strides_from_access_node", "gt_propagate_strides_of", "gt_reduce_distributed_buffering", + "gt_remove_copy_chain", "gt_set_gpu_blocksize", "gt_set_iteration_order", "gt_simplify", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py index 5a0e117b21..f691602764 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -6,16 +6,19 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional +from typing import Any, Optional, Sequence import dace from dace import ( data as dace_data, properties as dace_properties, + subsets as dace_sbs, + symbolic as dace_sym, transformation as dace_transformation, ) -from dace.sdfg import nodes as dace_nodes +from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import pass_pipeline as dace_ppl +from dace.transformation.passes import analysis as dace_analysis from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @@ -39,6 +42,36 @@ def gt_multi_state_global_self_copy_elimination( return res["MultiStateGlobalSelfCopyElimination"][sdfg] +def gt_remove_copy_chain( + sdfg: dace.SDFG, + validate: bool = False, + validate_all: bool = False, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, +) -> Optional[int]: + """Applies the `CopyChainRemover` transformation to the SDFG. + + The transformation returns the number of removed data containers or `None` + if nothing was done. + + Args: + sdfg: The SDFG to process. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + single_use_data: Which data descriptors are used only once. + If not passed the function will run `FindSingleUseData`. + """ + if single_use_data is None: + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + result: int = sdfg.apply_transformations_repeated( + CopyChainRemover(single_use_data=single_use_data), + validate=validate, + validate_all=validate_all, + ) + return result if result != 0 else None + + @dace_properties.make_properties class MultiStateGlobalSelfCopyElimination(dace_transformation.Pass): """Removes self copying across different states. @@ -562,3 +595,457 @@ def apply( except ValueError as e: if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): raise + + +@dace_properties.make_properties +class CopyChainRemover(dace_transformation.SingleStateTransformation): + """Removes chain of redundant copies, mostly related to `concat_where`. + + `concat_where`, especially when nested, will build "chains" of AccessNodes, + this transformation will remove them. It should be called repeatedly until a + fix point is reached and should be seen as an addition to the array removal passes + that ship with DaCe. + The transformation will look for the pattern `(A1) -> (A2)`, i.e. a data container + is copied into another one. The transformation will then remove `A1` and rewire + the edges such that they now refer to `A2`. Another, and probably better way, is to + consider the transformation as fusion transformation for AccessNodes. + + The transformation builds on ADR-18 and imposes the following additional + requirements before it can be applied: + - Through the merging of `A1` and `A2` no cycles are created. + - `A1` can not be used anywhere else. + - `A1` is fully read by `A2`. + - `A1` is a transient and must have the same dimensionality than `A2`. + + Notes: + - The transformation assumes that the domain inference adjusted the ranges of + the maps such that, in case they write into a transient, the full shape of the transient array is written. + has the same size, i.e. there is not padding, or data that is not written + to. + + Args: + single_use_data: List of data containers that are used only at one place. + Will be stored internally and not updated. + + Todo: + - Extend such that not the full array must be read. + - Try to allow more than one connection between `A1` and `A2`. + - Modify it such that also `A2` can be removed. + """ + + node_a1 = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_a2 = dace_transformation.PatternNode(dace_nodes.AccessNode) + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. + _single_use_data: dict[dace.SDFG, set[str]] + + def __init__( + self, + *args: Any, + single_use_data: dict[dace.SDFG, set[str]], + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._single_use_data = single_use_data + + @classmethod + def expressions(cls) -> Any: + return [ + dace.sdfg.utils.node_path_graph( + cls.node_a1, + cls.node_a2, + ) + ] + + def can_be_applied( + self, + graph: dace.SDFGState, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + a1: dace_nodes.AccessNode = self.node_a1 + a2: dace_nodes.AccessNode = self.node_a2 + + a1_desc = a1.desc(sdfg) + a2_desc = a2.desc(sdfg) + + # We remove `a1` so it must be a transient and used only once. + if not a1_desc.transient: + return False + if not self.is_single_use_data(sdfg, a1): + return False + + # This avoids that we have to modify the subsets in a fancy way. + if len(a1_desc.shape) != len(a2_desc.shape): + return False + + # For simplicity we assume that neither of `a1` nor `a2` are views. + # TODO(phimuell): Implement some of the cases. + if gtx_transformations.utils.is_view(a1_desc, None): + return False + if gtx_transformations.utils.is_view(a2_desc, None): + return False + + # We only allow that we operate on the top level scope. + if graph.scope_dict()[a1] is not None: + return False + + # TODO(phimuell): Relax this to only prevent host-device copies. + if a1_desc.storage != a2_desc.storage: + return False + + # There shall only be one edge connecting `a1` and `a2`. + # We even strengthen this requirement by not checking for the node `a2`, + # but for the data. + connecting_edges = [ + oedge + for oedge in graph.out_edges(a1) + if isinstance(oedge.dst, dace_nodes.AccessNode) and (oedge.dst.data == a2.data) + ] + if len(connecting_edges) != 1: + return False + + # The full array `a1` is copied into `a2`. Note that it is allowed, that + # `a2` is bigger than `a1`, it is just important that everything that was + # written into `a1` is also accessed. + connecting_edge = connecting_edges[0] + assert connecting_edge.dst is a2 + connecting_memlet = connecting_edge.data + + # If the destination or the source subset of the connection is not fully + # specified, we do not apply. + src_subset = connecting_memlet.get_src_subset(connecting_edge, graph) + if src_subset is None: + return False + dst_subset = connecting_memlet.get_dst_subset(connecting_edge, graph) + if dst_subset is None: + return False + + # NOTE: The main benefit of requiring that the whole array is read is + # that we do not have to adjust maps. + a1_range = dace_sbs.Range.from_array(a1_desc) + if not src_subset.covers(a1_range): + return False + + # We have to ensure that no cycle is created through the removal of `a1`. + # For this we have to ensure that there is no connection, beside the direct + # one between `a1` and `a2`. + # NOTE: We only check the outgoing edges of `a1`, it is not needed to also + # check the incoming edges, because this will not create a cycle. + if gtx_transformations.utils.is_reachable( + start=[oedge.dst for oedge in graph.out_edges(a1) if oedge.dst is not a2], + target=a2, + state=graph, + ): + return False + + # NOTE: In case `a2` is a non transient we do not have to check if it is read + # or written to somewhere else in this state. The reason is that ADR18 + # guarantees us that everything is point wise, therefore `a1` is never + # used as double buffer. + return True + + def is_single_use_data( + self, + sdfg: dace.SDFG, + data: str | dace_nodes.AccessNode, + ) -> bool: + """Checks if `data` is a single use data.""" + assert sdfg in self._single_use_data + if isinstance(data, dace_nodes.AccessNode): + data = data.data + return data in self._single_use_data[sdfg] + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + a1: dace_nodes.AccessNode = self.node_a1 + a2: dace_nodes.AccessNode = self.node_a2 + a1_to_a2_edge: dace_graph.MultiConnectorEdge = next( + oedge for oedge in graph.out_edges(a1) if oedge.dst is a2 + ) + a1_to_a2_memlet: dace.Memlet = a1_to_a2_edge.data + a1_to_a2_dst_subset: dace_sbs.Range = a1_to_a2_memlet.get_dst_subset(a1_to_a2_edge, graph) + + # Note that it is possible that `a1` is connected to the same node multiple + # times, although through different edges. We have to modify the data + # flow there, since the offsets and the data have changed. However, we must + # do this only once. Note that only matching the node is not enough, a + # counter example would be a Map with different connector names. + reconfigured_neighbour: set[tuple[dace_nodes.Node, Optional[str]]] = set() + + # Now we compose the new subset. + # We build on the fact that we have ensured that the whole array `a1` is + # copied into `a2`. Thus the destination of the original source, i.e. + # whatever write into `a1`, is just offset by the beginning of the range + # `a1` writes into `a2`. + # (s1) ------[c:d]-> (A1) -[0:N]------[a:b]-> (A2) + # (s1) ---------[(a + c):(a + c + (d - c))]-> (A2) + # Thus the offset is simply given by `a`, the start where `a1` is written into + # `a2`. + # NOTE: If we ever allow the that `a1` is not fully read, then we would have + # to modify this computation slightly. + a2_offsets: Sequence[dace_sym.SymExpr] = a1_to_a2_dst_subset.min_element() + + # Handle the producer side of things. + for producer_edge in list(graph.in_edges(a1)): + producer: dace_nodes.Node = producer_edge.src + producer_conn = producer_edge.src_conn + new_producer_edge = self._reroute_edge( + is_producer_edge=True, + current_edge=producer_edge, + a2_offsets=a2_offsets, + state=graph, + sdfg=sdfg, + a1=a1, + a2=a2, + ) + if (producer, producer_conn) not in reconfigured_neighbour: + self._reconfigure_dataflow( + is_producer_edge=True, + new_edge=new_producer_edge, + sdfg=sdfg, + state=graph, + a2_offsets=a2_offsets, + a1=a1, + a2=a2, + ) + reconfigured_neighbour.add((producer, producer_conn)) + + # Handle the consumer side of things, as they now have to read from `a2`. + # It is important that the offset is still the same. + for consumer_edge in list(graph.out_edges(a1)): + consumer: dace_nodes.Node = consumer_edge.dst + consumer_conn = consumer_edge.dst_conn + if consumer is a2: + assert consumer_edge is a1_to_a2_edge + continue + new_consumer_edge = self._reroute_edge( + is_producer_edge=False, + current_edge=consumer_edge, + a2_offsets=a2_offsets, + state=graph, + sdfg=sdfg, + a1=a1, + a2=a2, + ) + if (consumer, consumer_conn) not in reconfigured_neighbour: + self._reconfigure_dataflow( + is_producer_edge=False, + new_edge=new_consumer_edge, + sdfg=sdfg, + state=graph, + a2_offsets=a2_offsets, + a1=a1, + a2=a2, + ) + reconfigured_neighbour.add((consumer, consumer_conn)) + + # After the rerouting we have to delete the `a1` data node and descriptor, + # this will also remove all the old edges. + graph.remove_node(a1) + sdfg.remove_data(a1.data, validate=False) + + # We will now propagate the strides starting from the access nodes `a2`. + # Essentially, this will replace the strides from `a1` with the ones of + # `a2`. We do it outside to make sure that we do not forget a case and + # that we propagate the change into every NestedSDFG only once. + gtx_transformations.gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=graph, + outer_node=a2, + ) + + def _reroute_edge( + self, + is_producer_edge: bool, + current_edge: dace_graph.MultiConnectorEdge, + a2_offsets: Sequence[dace_sym.SymExpr], + state: dace.SDFGState, + sdfg: dace.SDFG, + a1: dace_nodes.AccessNode, + a2: dace_nodes.AccessNode, + ) -> dace_graph.MultiConnectorEdge: + """Performs the rerouting of the edge. + + Essentially the function creates new edges that account for the fact that + `a1` will be replaced with `a2`. Depending on the value of `is_producer_edge` + the behaviour is slightly different. + + If `is_producer_edge` is `True` then the function assumes that `current_edge` + ends at `a1`. It will then create a new edge that has the same start and a + similar Memlet but ends at `a2`. + If `is_producer_edge` is `False` then the function assumes that `current_edge` + starts at `a1`. It will then create a new edge that starts at `a2` but has the + same destination and a similar Memlet. + In both cases the Memlet and the corresponding subset, will be modified such + that they account that `a1` was replaced with `a2`. + + It is important that the function will **not** do the following things: + - Remove the old edge, i.e. `producer_edge`. + - Modify the data flow at the other side of the edge. + + The function returns the new edge. + + Args: + is_producer_edge: Indicates how to interpret `current_edge`. + current_edge: The current edge that should be replaced. + a2_offsets: Offset that describes how much to shift writes and reads, + that were previously associated with `a1`. + state: The state in which we operate. + sdfg: The SDFG on which we operate on. + a1: The `a1` node. + a2: The `a2` node. + + """ + current_memlet: dace.Memlet = current_edge.data + if is_producer_edge: + current_subset: dace_sbs.Range = current_memlet.get_dst_subset(current_edge, state) + new_src = current_edge.src + new_src_conn = current_edge._src_conn + new_dst = a2 + new_dst_conn = None + assert current_edge.dst_conn is None + else: + current_subset = current_memlet.get_src_subset(current_edge, state) + new_src = a2 + new_src_conn = None + new_dst = current_edge.dst + new_dst_conn = current_edge.dst_conn + assert current_edge.src_conn is None + + # If the subset we care about, which is always on the `a1` side, was not + # specified we assume that the whole `a1` has been written. + # TODO(edopao): Fix lowering that this does not happens, it happens for example + # in `tests/next_tests/integration_tests/feature_tests/ffront_tests/ + # test_execution.py::test_docstring`. + if current_subset is None: + current_subset = dace_sbs.Range.from_array(a1.desc(sdfg)) + + # This is the new Memlet, that we will use. We copy it from the original + # Memlet and modify it later. + new_memlet: dace.Memlet = dace.Memlet.from_memlet(current_memlet) + + # Because we operate on the `subset` and `other_subset` properties directly + # we do not need to distinguish between the different directions. Also + # in both cases the offset is the same. + if new_memlet.data == a1.data: + new_memlet.data = a2.data + new_subset = current_subset.offset_new(a2_offsets, negative=False) + new_memlet.subset = new_subset + else: + new_subset = current_subset.offset_new(a2_offsets, negative=False) + new_memlet.other_subset = new_subset + + new_edge = state.add_edge( + new_src, + new_src_conn, + new_dst, + new_dst_conn, + new_memlet, + ) + assert ( # Ensure that the edge has the right direction. + new_subset is new_edge.data.dst_subset + if is_producer_edge + else new_subset is new_edge.data.src_subset + ) + return new_edge + + def _reconfigure_dataflow( + self, + is_producer_edge: bool, + new_edge: dace_graph.MultiConnectorEdge, + a2_offsets: Sequence[dace_sym.SymExpr], + state: dace.SDFGState, + sdfg: dace.SDFG, + a1: dace_nodes.AccessNode, + a2: dace_nodes.AccessNode, + ) -> None: + """Modify the data flow associated to `new_edge`. + + The `_reroute_edge()` function creates a new edge, but it does not modify + the data flow at the other side, of the connection, this is done by this + function. + + Depending on the value of `is_producer_edge` the function will either modify + the source of `new_edge` (`True`) or it will modify the data flow associated + to the destination of `new_edge` (`False`). + Furthermore, the specific actions depends on what kind of node is on the other + side. However, essentially the function will modify it to account for the + change from `a1` to `a2`. + + It is important that it is the caller's responsibility to ensure that this + function is not called multiple times on the same producer target. + + It is important that this function will not propagate the new strides. This + must be done from the outside. + + Args: + is_producer_edge: If `True` then the source of `new_edge` is processed, + if `False` then the destination part of `new_edge` is processed. + new_edge: The newly created edge, essentially the return value of + `self._reroute__edge()`. + a2_offsets: Offset that describes how much to shift subsets associated + to `a1` to account that they are now associated to `a2`. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + a1: The `a1` node. + a2: The `a2` node. + """ + other_node = new_edge.src if is_producer_edge else new_edge.dst + + if isinstance(other_node, dace_nodes.AccessNode): + # There is nothing here to do. + pass + + elif isinstance(other_node, dace_nodes.Tasklet): + # A very obscure case, but I think it might happen, but as in the AccessNode + # case there is nothing to do here. + pass + + elif isinstance(other_node, (dace_nodes.MapExit | dace_nodes.MapEntry)): + # Essentially, we have to propagate the change that everything that + # refers to `a1` should now refer to `a2`, In addition we also have to + # modify the subsets, depending on the direction of the new edge either + # the source or destination subset. + # NOTE: Because we assume that `a1` is read fully into `a2` we do not + # have to adjust the ranges of the Map. If we would drop this assumption + # then we would have to modify the ranges such that only the ranges we + # need are computed. + # NOTE: Also for this case we have to propagate the strides, for the case + # that a NestedSDFG is inside the map, but this is done externally. + assert ( + isinstance(other_node, dace_nodes.MapExit) + if is_producer_edge + else isinstance(other_node, dace_nodes.MapEntry) + ) + for memlet_tree in state.memlet_tree(new_edge).traverse_children(include_self=False): + edge_to_adjust = memlet_tree.edge + memlet_to_adjust = edge_to_adjust.data + if memlet_to_adjust.data == a1.data: + memlet_to_adjust.data = a2.data + + if is_producer_edge: + subset_to_adjust = memlet_to_adjust.get_dst_subset(edge_to_adjust, state) + else: + subset_to_adjust = memlet_to_adjust.get_src_subset(edge_to_adjust, state) + assert subset_to_adjust is not None + subset_to_adjust.offset(a2_offsets, negative=False) + + elif isinstance(other_node, dace_nodes.NestedSDFG): + # We have obviously to adjust the strides, however, this is done outside + # this function. + # TODO(phimuell): Look into the implication that we not necessarily pass + # the full array, but essentially slice a bit. + pass + + else: + # As we encounter them we should handle them case by case. + raise NotImplementedError( + f"The case for '{type(other_node).__name__}' has not been implemented." + ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index 1c2541ed99..c2c5acf05f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -63,6 +63,8 @@ def gt_simplify( `SingleStateGlobalSelfCopyElimination`, with the exception that the write to `T`, i.e. `(G) -> (T)` and the write back to `G`, i.e. `(T) -> (G)` might be in different states. + - `CopyChainRemover`: Which removes some chains that are introduced by the + `concat_where` built-in function. Furthermore, by default, or if `None` is passed for `skip` the passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. @@ -90,6 +92,24 @@ def gt_simplify( while at_least_one_xtrans_run: at_least_one_xtrans_run = False + # NOTE: See comment in `gt_inline_nested_sdfg()` for more. + sdfg.reset_cfg_list() + + # To mitigate DaCe issue 1959, we run the chain removal transformation here. + # TODO(phimuell): Remove as soon as we have a true solution. + if "CopyChainRemover" not in skip: + copy_chain_remover_result = gtx_transformations.gt_remove_copy_chain( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + if copy_chain_remover_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "CopyChainRemover" not in result: + result["CopyChainRemover"] = 0 + result["CopyChainRemover"] += copy_chain_remover_result + if "InlineSDFGs" not in skip: inline_res = gt_inline_nested_sdfg( sdfg=sdfg, @@ -115,6 +135,20 @@ def gt_simplify( result = result or {} result.update(simplify_res) + # This is the place were we actually want to apply the chain removal. + if "CopyChainRemover" not in skip: + copy_chain_remover_result = gtx_transformations.gt_remove_copy_chain( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + if copy_chain_remover_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "CopyChainRemover" not in result: + result["CopyChainRemover"] = 0 + result["CopyChainRemover"] += copy_chain_remover_result + if "SingleStateGlobalSelfCopyElimination" not in skip: self_copy_removal_result = sdfg.apply_transformations_repeated( gtx_transformations.SingleStateGlobalSelfCopyElimination(), diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index d315f99264..7afb93d5c2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,7 +8,7 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Any, Container, Optional, Union +from typing import Any, Container, Optional, Sequence, Union import dace from dace import data as dace_data @@ -232,12 +232,46 @@ def is_accessed_downstream( return False +def is_reachable( + start: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + target: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + state: dace.SDFGState, +) -> bool: + """Explores the graph from `start` and checks if `target` is reachable. + + The exploration of the graph is done in a way that ignores the connector names. + It is possible to pass multiple start nodes and targets. In case of multiple target nodes, the function returns True if any of them is reachable. + + Args: + start: The node from where to start. + target: The node to look for. + state: The SDFG state on which we operate. + """ + to_visit: list[dace_nodes.Node] = [start] if isinstance(start, dace_nodes.Node) else list(start) + targets: set[dace_nodes.Node] = {target} if isinstance(target, dace_nodes.Node) else set(target) + seen: set[dace_nodes.Node] = set() + + while to_visit: + node = to_visit.pop() + if node in targets: + return True + seen.add(node) + to_visit.extend(oedge.dst for oedge in state.out_edges(node) if oedge.dst not in seen) + + return False + + def is_view( node: Union[dace_nodes.AccessNode, dace_data.Data], - sdfg: dace.SDFG, + sdfg: Optional[dace.SDFG] = None, ) -> bool: """Tests if `node` points to a view or not.""" - node_desc: dace_data.Data = node.desc(sdfg) if isinstance(node, dace_nodes.AccessNode) else node + if isinstance(node, dace_nodes.AccessNode): + assert sdfg is not None + node_desc = node.desc(sdfg) + else: + assert isinstance(node, dace_data.Data) + node_desc = node return isinstance(node_desc, dace_data.View) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py new file mode 100644 index 0000000000..8251352e49 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py @@ -0,0 +1,553 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import copy +import numpy as np + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + + +def _make_simple_linear_chain_sdfg() -> dace.SDFG: + """Creates a simple linear chain. + + All intermediates have the same size. + """ + sdfg = dace.SDFG(util.unique_name("simple_linear_chain_sdfg")) + + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[0:10] -> [0:10]")) + state.add_nedge(c, d, dace.Memlet("c[0:10] -> [0:10]")) + state.add_nedge(d, e, dace.Memlet("d[0:10] -> [0:10]")) + sdfg.validate() + return sdfg + + +def _make_diff_sizes_linear_chain_sdfg() -> ( + tuple[dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.Tasklet] +): + """Creates a linear chain of copies. + + The main differences compared to the SDFG made by `_make_simple_linear_chain_sdfg()` + is that here the intermediate arrays have different sizes, that become bigger. + It essentially checks the adjusting of the memlet subset during copying. + + The function returns a tuple with the following content. + - The SDFG that was generated. + - The SDFG state. + - The AccessNode that is used as final output, refers to `e`. + - The Tasklet that is within the Map. + """ + sdfg = dace.SDFG(util.unique_name("diff_size_linear_chain_sdfg")) + + array_size_increment = 10 + array_size = 10 + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(array_size,), + dtype=dace.float64, + transient=True, + ) + array_size += array_size_increment + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + assert sdfg.arrays["e"].shape[0] == 50 + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + tasklet, _, _ = state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i + 3]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[0:20] -> [10:30]")) + state.add_nedge(c, d, dace.Memlet("c[0:30] -> [2:32]")) + state.add_nedge(d, e, dace.Memlet("d[0:40] -> [3:43]")) + sdfg.validate() + return sdfg, state, e, tasklet + + +def _make_multi_stage_reduction_sdfg() -> dace.SDFG: + """Creates an SDFG that has a two stage copy reduction.""" + sdfg = dace.SDFG(util.unique_name("multi_stage_reduction")) + state: dace.SDFGState = sdfg.add_state(is_start_block=True) + + # This is the size of the arrays, if not mentioned here, then its size is 10. + array_sizes: dict[str, int] = {"d": 20, "f": 40, "o1": 40} + def_array_size = 10 + + array_names: list[str] = ["i1", "i2", "i3", "i4", "a", "b", "c", "d", "e", "f", "o1"] + for name in array_names: + sdfg.add_array( + name, + shape=(array_sizes.get(name, def_array_size),), + dtype=dace.float64, + transient=(len(name) == 1), + ) + + a, b, c, d, e, f = (state.add_access(name) for name in "abcdef") + + state.add_mapped_tasklet( + "comp_i1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("i1[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("a[__i]")}, + output_nodes={a}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp_i2", + map_ranges={"__j": "0:10"}, + inputs={"__in": dace.Memlet("i2[__j]")}, + code="__out = __in + 2.", + outputs={"__out": dace.Memlet("b[__j]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp_i3", + map_ranges={"__k": "0:10"}, + inputs={"__in": dace.Memlet("i3[__k]")}, + code="__out = __in + 3.", + outputs={"__out": dace.Memlet("c[__k]")}, + output_nodes={c}, + external_edges=True, + ) + + state.add_nedge(state.add_access("i4"), e, dace.Memlet("i4[0:10] -> [0:10]")) + + state.add_nedge(b, d, dace.Memlet("b[0:10] -> [0:10]")) + state.add_nedge(c, d, dace.Memlet("c[0:10] -> [10:20]")) + + state.add_nedge(a, f, dace.Memlet("a[0:10] -> [0:10]")) + state.add_nedge(d, f, dace.Memlet("d[0:20] -> [10:30]")) + state.add_nedge(e, f, dace.Memlet("e[0:10] -> [30:40]")) + + state.add_nedge(f, state.add_access("o1"), dace.Memlet("f[0:40] -> [0:40]")) + + sdfg.validate() + return sdfg + + +def _make_not_fully_copied() -> dace.SDFG: + """ + Make an SDFG where two intermediate array is not fully copied. Thus the + transformation only applies once, when `d` is removed. + """ + sdfg = dace.SDFG(util.unique_name("not_fully_copied_intermediate")) + + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[2:10] -> [0:8]")) + state.add_nedge(c, d, dace.Memlet("c[0:8] -> [0:8]")) + state.add_nedge(d, e, dace.Memlet("d[0:10] -> [0:10]")) + sdfg.validate() + return sdfg + + +def _make_possible_cyclic_sdfg() -> dace.SDFG: + """ + If the transformation would remove `a1` then it would create a cycle. Thus the + transformation should not apply. + """ + sdfg = dace.SDFG(util.unique_name("possible_cyclic_sdfg")) + + anames = ["i1", "a1", "a2", "o1"] + for name in anames: + sdfg.add_array( + name, + shape=((30,) if name in ["o1", "a2"] else (10,)), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["a1"].transient = True + sdfg.arrays["a2"].transient = True + + state = sdfg.add_state(is_start_block=True) + i1, a1, a2, o1 = (state.add_access(name) for name in anames) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("i1[__i]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("a2[__i]")}, + input_nodes={i1}, + output_nodes={a2}, + external_edges=True, + ) + + state.add_nedge(i1, a1, dace.Memlet("i1[0:10] -> [0:10]")) + state.add_nedge(a1, a2, dace.Memlet("a1[0:10] -> [10:20]")) + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__j": "0:10"}, + inputs={"__in": dace.Memlet("a1[__j]")}, + code="__out = __in + 2.0", + outputs={"__out": dace.Memlet("a2[__j + 20]")}, + input_nodes={a1}, + output_nodes={a2}, + external_edges=True, + ) + + state.add_nedge(a2, o1, dace.Memlet("a2[0:30] -> [0:30]")) + + sdfg.validate() + return sdfg + + +def _make_linear_chain_with_nested_sdfg_sdfg() -> tuple[dace.SDFG, dace.SDFG]: + """ + The structure is very similar than `_make_diff_sizes_linear_chain_sdfg()`, with + the main difference that the Map is inside a NestedSDFG. + """ + + def make_inner_sdfg() -> dace.SDFG: + inner_sdfg = dace.SDFG("inner_sdfg") + inner_state = inner_sdfg.add_state(is_start_block=True) + for name in ["i0", "o0"]: + inner_sdfg.add_array(name=name, shape=(10, 10), dtype=dace.float64, transient=False) + inner_state.add_mapped_tasklet( + "inner_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in": dace.Memlet("i0[__i0, __i1]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("o0[__i0, __i1]")}, + external_edges=True, + ) + inner_sdfg.validate() + return inner_sdfg + + inner_sdfg = make_inner_sdfg() + + sdfg = dace.SDFG(util.unique_name("linear_chain_with_nested_sdfg")) + state = sdfg.add_state(is_start_block=True) + + array_size_increment = 10 + array_size = 10 + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(array_size, array_size), + dtype=dace.float64, + transient=True, + ) + if name != "a": + array_size += array_size_increment + assert sdfg.arrays["a"].shape == sdfg.arrays["b"].shape + assert sdfg.arrays["e"].shape == (40, 40) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + a, b, c, d, e = (state.add_access(name) for name in "abcde") + + nsdfg = state.add_nested_sdfg( + inner_sdfg, + parent=sdfg, + inputs={"i0"}, + outputs={"o0"}, + symbol_mapping={}, + ) + + state.add_edge(a, None, nsdfg, "i0", sdfg.make_array_memlet("a")) + state.add_edge(nsdfg, "o0", b, None, sdfg.make_array_memlet("b")) + + state.add_nedge(b, c, dace.Memlet("b[0:10, 0:10] -> [5:15, 3:13]")) + state.add_nedge(c, d, dace.Memlet("c[0:20, 0:20] -> [2:22, 6:26]")) + state.add_nedge(d, e, dace.Memlet("d[0:30, 0:30] -> [1:31, 8:38]")) + sdfg.validate() + return sdfg, inner_sdfg + + +def _make_a1_has_output_sdfg() -> dace.SDFG: + """Here `a1` has an output degree of 2, one to `a2` and one to another output.""" + sdfg = dace.SDFG(util.unique_name("a1_has_an_additional_output_sdfg")) + state = sdfg.add_state(is_start_block=True) + + # All other arrays have a size of 10. + anames = ["i1", "i2", "i3", "a1", "a2", "o1", "o2"] + def_array_size = 10 + asizes = {"a1": 20, "a2": 30, "o2": 30} + for name in anames: + sdfg.add_array( + name=name, + shape=(asizes.get(name, def_array_size),), + dtype=dace.float64, + transient=name[0] == "a", + ) + a1, a2 = (state.add_access("a1"), state.add_access("a2")) + + state.add_nedge(state.add_access("i1"), a1, dace.Memlet("i1[0:10] -> [0:10]")) + state.add_nedge(state.add_access("i2"), a1, dace.Memlet("i2[0:10] -> [10:20]")) + + state.add_nedge(a1, state.add_access("o1"), dace.Memlet("a1[5:15] -> [0:10]")) + + state.add_nedge(state.add_access("i3"), a2, dace.Memlet("i3[0:10] -> [0:10]")) + state.add_nedge(a1, a2, dace.Memlet("a1[0:20] -> [10:30]")) + + state.add_nedge(a2, state.add_access("o2"), dace.Memlet("a2[0:30] -> [0:30]")) + + sdfg.validate() + return sdfg + + +def test_simple_linear_chain(): + sdfg = _make_simple_linear_chain_sdfg() + + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 2 + assert not any(ac.desc(sdfg).transient for ac in acnodes) + assert nb_applies == 3 + + +def test_diff_size_linear_chain(): + sdfg, state, output, tasklet = _make_diff_sizes_linear_chain_sdfg() + + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 2 + assert not any(ac.desc(sdfg).transient for ac in acnodes) + assert nb_applies == 3 + assert output in acnodes + assert state.in_degree(output) == 1 + assert state.out_degree(output) == 0 + + # Look if the subsets were correctly adapted, for that we look at the output + # AccessNode and the tasklet inside the map. + output_memlet: dace.Memlet = next(iter(state.in_edges(output))).data + assert output_memlet.dst_subset.min_element()[0] == 18 + assert output_memlet.dst_subset.max_element()[0] == 27 + + tasklet_memlet: dace.Memlet = next(iter(state.out_edges(tasklet))).data + assert str(tasklet_memlet.subset[0][0] - 18).strip() == "__i" + + +def test_multi_stage_reduction(): + sdfg = _make_multi_stage_reduction_sdfg() + + # Make the input + ref = { + "i1": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i2": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i3": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i4": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "o1": np.zeros(40, dtype=np.float64), + } + res = copy.deepcopy(ref) + + # Generate the reference solution. + csdfg_ref = sdfg.compile() + csdfg_ref(**ref) + + # Apply the transformation. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Run the processed SDFG + csdfg_res = sdfg.compile() + csdfg_res(**res) + + # Perform all the checks. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 5 + assert not any(ac.desc(sdfg).transient for ac in acnodes) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) + + +def test_not_fully_copied(): + sdfg = _make_not_fully_copied() + + # Apply the transformation. + # It will only remove `d` all the others are retained, because they are not read + # correctly, i.e. fully. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform all the checks. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 4 + assert nb_applies == 1 + assert "d" not in acnodes + + +def test_possible_cyclic_sdfg(): + sdfg = _make_possible_cyclic_sdfg() + + # Apply the transformation. + # It will not remove `a1`, because it it would and replace it with `a2` then + # the resulting SDFG is cyclic. It will, however, replace `a2` with `o1`. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform all the checks. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 3 + assert nb_applies == 1 + assert "o1" not in acnodes + + +def test_a1_additional_output(): + sdfg = _make_a1_has_output_sdfg() + + # Make the input + ref = { + "i1": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i2": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i3": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "o1": np.zeros(10, dtype=np.float64), + "o2": np.zeros(30, dtype=np.float64), + } + res = copy.deepcopy(ref) + + csdfg_ref = sdfg.compile() + csdfg_ref(**ref) + + # Apply the transformation. + # The transformation removes `a1` and `a2`. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform the tests. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 5 + assert nb_applies == 2 + assert not any(acnode.data.startswith("a") for acnode in acnodes) + + # Now run the SDFG, which is essentially to check if the subsets were handled + # correctly. This is especially important for `o1` which is composed of both + # `i1` and `i2`. + csdfg_res = sdfg.compile() + csdfg_res(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) + + +def test_linear_chain_with_nested_sdfg(): + sdfg, inner_sdfg = _make_linear_chain_with_nested_sdfg_sdfg() + + # Ensure that the SDFG was constructed in the correct way. + assert inner_sdfg.arrays["i0"].strides == sdfg.arrays["a"].strides + assert inner_sdfg.arrays["o0"].strides == sdfg.arrays["b"].strides + assert inner_sdfg.arrays["i0"].shape == inner_sdfg.arrays["o0"].shape + assert inner_sdfg.arrays["i0"].shape == sdfg.arrays["a"].shape + + def ref_comp(a, e): + def inner_ref(i0, o0): + for i in range(10): + for j in range(10): + o0[i, j] = i0[i, j] + 10 + + b, c, d = np.zeros_like(a), np.zeros((20, 20)), np.zeros((30, 30)) + inner_ref(i0=a, o0=b) + c[5:15, 3:13] = b + d[2:22, 6:26] = c + e[1:31, 8:38] = d + + # Make the input + ref = { + "a": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + "e": np.zeros((40, 40), dtype=np.float64), + } + res = copy.deepcopy(ref) + + ref_comp(**ref) + + # Apply the transformation. + # It should remove all non transient arrays. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform the tests. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert {ac.data for ac in acnodes} == {"a", "e"} + assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 + + # The shapes should be the same as before. + assert inner_sdfg.arrays["i0"].shape == inner_sdfg.arrays["o0"].shape + assert inner_sdfg.arrays["i0"].shape == sdfg.arrays["a"].shape + + # The strides of `i0` should also be the same as before, but the strides + # of `o0` should now be the same as `e`. + assert inner_sdfg.arrays["i0"].strides == sdfg.arrays["a"].strides + assert inner_sdfg.arrays["o0"].strides == sdfg.arrays["e"].strides + + # Now run the transformed SDFG to see if the same output is generated. + csdfg_res = sdfg.compile() + csdfg_res(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index d3aadf8927..8befcf0610 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -18,7 +18,6 @@ ) from . import util -import dace def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: From 5a084263fa81ee4ec5cfc7149a4e9988f2f0086e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 13 Mar 2025 06:56:01 +0100 Subject: [PATCH 171/178] fix[dace][next]: Fixed some undefined behaviour in the chain remover. (#1910) The issue was discovered by Edoardo (@edopao). The underlying problem is that `get_src_subset()` tries to set the direction of the Memlet, for that it looks for the source of the data. However, when we called it then the source is unspecific because the new edge is already there and the old one was not removed. This might implicitly change the direction of the Memlet. I have removed the calls where they are potentially dangerous and added some counter measure to avoid the problem. --- .../redundant_array_removers.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py index f691602764..8919d2bc0f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -60,6 +60,13 @@ def gt_remove_copy_chain( single_use_data: Which data descriptors are used only once. If not passed the function will run `FindSingleUseData`. """ + + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1703](https://github.com/spcl/dace/issues/1703) + for state in sdfg.states(): + for edge in state.edges(): + edge.data.try_initialize(sdfg, state, edge) + if single_use_data is None: find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) @@ -905,14 +912,17 @@ def _reroute_edge( """ current_memlet: dace.Memlet = current_edge.data if is_producer_edge: - current_subset: dace_sbs.Range = current_memlet.get_dst_subset(current_edge, state) + # NOTE: See the note in `_reconfigure_dataflow()` why it is not save to + # use the `get_{dst, src}_subset()` function, although it would be more + # appropriate. + current_subset: dace_sbs.Range = current_memlet.dst_subset new_src = current_edge.src new_src_conn = current_edge._src_conn new_dst = a2 new_dst_conn = None assert current_edge.dst_conn is None else: - current_subset = current_memlet.get_src_subset(current_edge, state) + current_subset = current_memlet.src_subset new_src = a2 new_src_conn = None new_dst = current_edge.dst @@ -1027,13 +1037,19 @@ def _reconfigure_dataflow( for memlet_tree in state.memlet_tree(new_edge).traverse_children(include_self=False): edge_to_adjust = memlet_tree.edge memlet_to_adjust = edge_to_adjust.data + + # NOTE: Actually we should use the `get_{src, dst}_subset()` functions, + # see https://github.com/spcl/dace/issues/1703. However, we can not + # do that because the SDFG is currently in an invalid state. So + # we have to call the properties and hope that it works. + subset_to_adjust = ( + memlet_to_adjust.dst_subset if is_producer_edge else memlet_to_adjust.src_subset + ) + + # If needed modify the association of the Memlet. if memlet_to_adjust.data == a1.data: memlet_to_adjust.data = a2.data - if is_producer_edge: - subset_to_adjust = memlet_to_adjust.get_dst_subset(edge_to_adjust, state) - else: - subset_to_adjust = memlet_to_adjust.get_src_subset(edge_to_adjust, state) assert subset_to_adjust is not None subset_to_adjust.offset(a2_offsets, negative=False) From ee950ee38f3994f669cc0a653d8a1a5a9a32c7e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 13 Mar 2025 06:57:00 +0100 Subject: [PATCH 172/178] feat[dace][next]: Updated GPU Transformation Scheme (#1904) This PR slightly changes how the GPU transformation works. It mainly changes how trivial Maps are eliminated. First, there is now a dedicated function, `gt_remove_trivial_gpu_maps()` for this. Furthermore, before it was only using the `TrivialGPUMapElimination` transformation, which tries to promote and fuse the trivial maps, but only with downstream maps. Now there is a second stage that tries to fuse the trivial maps together. This is mostly to reduce the number of kernel calls that we are doing. --- .../runners/dace/transformations/gpu_utils.py | 87 +++++++++++++++++-- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index f4372db80f..e1f105f0ef 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -92,15 +92,10 @@ def gt_gpu_transformation( gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So - # we might end up with lots of these trivial Maps, each requiring a separate - # kernel launch. To prevent this we will combine these trivial maps, if - # possible, with their downstream maps. - sdfg.apply_transformations_once_everywhere( - TrivialGPUMapElimination(), - validate=False, - validate_all=False, + gt_remove_trivial_gpu_maps( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, ) gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) @@ -578,6 +573,80 @@ def apply( gpu_map.gpu_launch_bounds = launch_bounds +def gt_remove_trivial_gpu_maps( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Removes trivial maps that were created by the GPU transformation. + + The main problem is that a Tasklet outside of a Map cannot write into an + _array_ that is on GPU. `sdfg.apply_gpu_transformations()` will wrap such + Tasklets in a Map. The `GT4PyMoveTaskletIntoMap` pass, that runs before, + but only works if the tasklet is adjacent to a map. + + It first tries to promote them such that they can be fused in other non-trivial + maps, it will then also perform fusion on them, to reduce the number of kernel + calls. + + Args: + sdfg: The SDFG that we process. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + + # First we try to promote and fuse them with other non-trivial maps. + sdfg.apply_transformations_once_everywhere( + TrivialGPUMapElimination( + do_not_fuse=False, + only_gpu_maps=True, + ), + validate=False, + validate_all=False, + ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # Now we try to fuse them together, however, we restrict the fusion to trivial + # GPU map. + def restrict_to_trivial_gpu_maps( + self: gtx_transformations.MapFusion, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + for map_entry in [map_entry_1, map_entry_2]: + _map = map_entry.map + if len(_map.params) != 1: + return False + if _map.range[0][0] != _map.range[0][1]: + return False + if _map.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + return True + + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_to_trivial_gpu_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_to_trivial_gpu_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + return sdfg + + @dace_properties.make_properties class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): """Eliminate certain kind of trivial GPU maps. From 0e362086d3fd7f24d021b1e1457b26afdd3987a0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 14 Mar 2025 15:32:02 +0100 Subject: [PATCH 173/178] docs: rename ADRs/Index.md to README.md (#1907) Documents describing the contents of a (sub-)folder are commonly named README.md. This is so common that e.g. GitHub will display the contents of a README.md file when looking at the folder. --- docs/development/ADRs/{Index.md => README.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/development/ADRs/{Index.md => README.md} (100%) diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/README.md similarity index 100% rename from docs/development/ADRs/Index.md rename to docs/development/ADRs/README.md From 42850f0422a516c937100744810d338d7b99d0b8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 14 Mar 2025 15:49:16 +0100 Subject: [PATCH 174/178] docs: rename ADRs/Index.md to README.md (#1907) ## Description Documents describing the contents of a (sub-)folder are commonly named `README.md`. This is so common that e.g. GitHub will display the contents of a `README.md` file when looking at the folder, e.g. https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs will show the contents of renamed file (after the proposed name change). ## Requirements - [ ] All fixes and/or new features come with corresponding tests. N/A - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> From 9be98f69070dab356b5ecf3a76fccdcf20967826 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 17 Mar 2025 14:06:23 +0100 Subject: [PATCH 175/178] refactor: centralize `CUPY_DEVICE_TYPE` in `_core/definitions` (#1908) ## Description This PR centralizes the definition of `CUPY_DEVICE_TYPE` in `_core/definitions`. This effectively de-duplicates the definition of `CUPY_DEVICE` in gt4py next and cartesian. Fixes a couple of typos along the way. Related issue: https://github.com/GridTools/gt4py/issues/1880 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Should be covered by existing tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/_core/definitions.py | 13 +++++++++- src/gt4py/next/allocators.py | 26 +++++-------------- .../runners/dace/workflow/backend.py | 2 +- .../runners/dace/workflow/translation.py | 4 +-- .../next/program_processors/runners/gtfn.py | 2 +- src/gt4py/storage/cartesian/utils.py | 15 +++-------- .../feature_tests/dace/test_orchestration.py | 2 +- .../runners_tests/dace_tests/test_dace.py | 12 ++++----- 8 files changed, 33 insertions(+), 43 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index ba273c75d9..41a592c3d4 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -39,15 +39,21 @@ ) -if TYPE_CHECKING: +try: import cupy as cp +except ImportError: + cp = None +if TYPE_CHECKING: CuPyNDArray: TypeAlias = cp.ndarray import jax.numpy as jnp JaxNDArray: TypeAlias = jnp.ndarray +# The actual assignment happens after the definition of `DeviceType` enum. +CUPY_DEVICE_TYPE: Literal[None, DeviceType.CUDA, DeviceType.ROCM] +"""Type of the GPU accelerator device, if present.""" # -- Scalar types supported by GT4Py -- bool_ = np.bool_ @@ -396,6 +402,11 @@ class DeviceType(enum.IntEnum): ) +CUPY_DEVICE_TYPE = ( + None if not cp else (DeviceType.ROCM if cp.cuda.runtime.is_hip else DeviceType.CUDA) +) + + @dataclasses.dataclass(frozen=True) class Device(Generic[DeviceTypeT]): """ diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 864f8c1b09..dae2e9d021 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -18,7 +18,6 @@ Any, Callable, Final, - Literal, Optional, Protocol, Sequence, @@ -28,19 +27,6 @@ ) -try: - import cupy as cp -except ImportError: - cp = None - - -CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( - None - if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.runtime.is_hip else core_defs.DeviceType.CUDA) -) - - FieldLayoutMapper: TypeAlias = Callable[ [Sequence[common.Dimension]], core_allocators.BufferLayoutMap ] @@ -180,7 +166,7 @@ def __gt_allocate__( def horizontal_first_layout_mapper( dims: Sequence[common.Dimension], ) -> core_allocators.BufferLayoutMap: - """Map dimensions to a buffer layout making horizonal dims change the slowest (i.e. larger strides).""" + """Map dimensions to a buffer layout making horizontal dims change the slowest (i.e. larger strides).""" def pos_of_kind(kind: common.DimensionKind) -> list[int]: return [i for i, dim in enumerate(dims) if dim.kind == kind] @@ -246,11 +232,11 @@ def __gt_allocate__( raise self.exception -if CUPY_DEVICE is not None: +if core_defs.CUPY_DEVICE_TYPE is not None: assert isinstance(core_allocators.cupy_array_utils, core_allocators.ArrayUtils) cupy_array_utils = core_allocators.cupy_array_utils - if CUPY_DEVICE is core_defs.DeviceType.CUDA: + if core_defs.CUPY_DEVICE_TYPE is core_defs.DeviceType.CUDA: class CUDAFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CUDADeviceTyping]): def __init__(self) -> None: @@ -278,7 +264,7 @@ def __init__(self) -> None: else: - class InvalidGPUFielBufferAllocator(InvalidFieldBufferAllocator[core_defs.CUDADeviceTyping]): + class InvalidGPUFieldBufferAllocator(InvalidFieldBufferAllocator[core_defs.CUDADeviceTyping]): def __init__(self) -> None: super().__init__( device_type=core_defs.DeviceType.CUDA, @@ -288,7 +274,9 @@ def __init__(self) -> None: StandardGPUFieldBufferAllocator: Final[type[FieldBufferAllocatorProtocol]] = cast( type[FieldBufferAllocatorProtocol], - type(device_allocators[CUPY_DEVICE]) if CUPY_DEVICE else InvalidGPUFielBufferAllocator, + type(device_allocators[core_defs.CUPY_DEVICE_TYPE]) + if core_defs.CUPY_DEVICE_TYPE + else InvalidGPUFieldBufferAllocator, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index 55d7122767..fb93e3df79 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -27,7 +27,7 @@ class Params: name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, name_device="gpu", ) cached = factory.Trait( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 6e1b3a6f32..e31e4ea741 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -15,7 +15,7 @@ import factory from gt4py._core import definitions as core_defs -from gt4py.next import allocators as gtx_allocators, common +from gt4py.next import common from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface @@ -84,7 +84,7 @@ def __call__( inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, auto_opt=self.auto_optimize, - on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), + on_gpu=(self.device_type == core_defs.CUPY_DEVICE_TYPE), ) param_types = tuple( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a8961fd9bc..12f5f34a7e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -185,7 +185,7 @@ class Params: name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, name_device="gpu", ) cached = factory.Trait( diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 2e1bfb69b5..d2c5ff066f 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -12,7 +12,7 @@ import functools import math import numbers -from typing import Final, Literal, Optional, Sequence, Tuple, Union, cast +from typing import Literal, Optional, Sequence, Tuple, Union, cast import numpy as np import numpy.typing as npt @@ -30,13 +30,6 @@ cp = None -CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( - None - if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) -) - - FieldLike = Union["cp.ndarray", np.ndarray, ArrayInterface, CUDAArrayInterface] _CPUBufferAllocator = allocators.NDArrayBufferAllocator( @@ -47,12 +40,12 @@ if cp: assert isinstance(allocators.cupy_array_utils, allocators.ArrayUtils) - if CUPY_DEVICE == core_defs.DeviceType.CUDA: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.CUDA: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( device_type=core_defs.DeviceType.CUDA, array_utils=allocators.cupy_array_utils, ) - elif CUPY_DEVICE == core_defs.DeviceType.ROCM: + elif core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( device_type=core_defs.DeviceType.ROCM, array_utils=allocators.cupy_array_utils, @@ -286,7 +279,7 @@ def _allocate_gpu( allocate_gpu = _allocate_gpu -if CUPY_DEVICE == core_defs.DeviceType.ROCM: +if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: class CUDAArrayInterfaceNDArray(cp.ndarray): def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 3ba376b08f..2fb780c1bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -86,7 +86,7 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 allocator, backend = unstructured_case.allocator, unstructured_case.backend - if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): + if gtx_allocators.is_field_allocator_for(allocator, core_defs.CUPY_DEVICE_TYPE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index ca4a1e0f1f..64ec757f16 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -11,7 +11,6 @@ import ctypes import unittest import unittest.mock -from unittest.mock import patch import numpy as np import pytest @@ -21,13 +20,12 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case # noqa: F401 from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - exec_alloc_descriptor, - mesh_descriptor, + exec_alloc_descriptor, # noqa: F401 + mesh_descriptor, # noqa: F401 ) - dace = pytest.importorskip("dace") @@ -177,7 +175,7 @@ def verify_testee(offset_provider): offset_provider = unstructured_case.offset_provider else: assert gtx.allocators.is_field_allocator_for( - unstructured_case.backend.allocator, gtx.allocators.CUPY_DEVICE + unstructured_case.backend.allocator, core_defs.CUPY_DEVICE_TYPE ) import cupy as cp @@ -186,7 +184,7 @@ def verify_testee(offset_provider): # to gpu memory at each program call (see `dace_backend._ensure_is_on_device`), # therefore fast_call cannot be used (unless cupy reuses the same cupy array # from the its memory pool, but this behavior is random and unpredictable). - # Here we copy the connectivity to gpu memory, and resuse the same cupy array + # Here we copy the connectivity to gpu memory, and reuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. offset_provider = { "E2V": gtx.as_connectivity( From 7c65eea6ddb4e691bb982d726ba3ae89e4d2a566 Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 17 Mar 2025 19:07:47 +0100 Subject: [PATCH 176/178] feat[next]: Extend and refactor constant folding (#1810) The following new transformations are introduced: - _CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL_: `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS - _CANONICALIZE_MINUS_: `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS - _CANONICALIZE_MIN_MAX_: `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for - _FOLD_FUNCALL_LITERAL_: `(a + 1) + 1` -> `a + (1 + 1)` - _FOLD_MIN_MAX_: `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` - FOLD_MIN_MAX_PLUS: `maximum(plus(a, 1), a)` -> `plus(a, 1)` `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))` - FOLD_NEUTRAL_OP: `a + 0` -> `a`, `a * 1` -> `a` In the end, unary minuses are transformed back to `minus`-calls where possible by `UndoCanonicalizeMinus`: `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` In addition the pass now transforms until a fixed point is reached (like in the `CollapseTuple` pass) instead of just transforming once. The `FixedPointTransformation` class from [PR 1826](https://github.com/GridTools/gt4py/pull/1826) is used here. Previously, large nested maximum expressions like `maximum(maximum(maximum(maximum(sym, 1),...), 1), maximum(maximum(sym, 1), 1))` caused timeouts in PMAP-GO as the runtime of the domain inference increased significantly due to the large domain expressions that could not be simplified. This replaces [this PR](https://github.com/tehrengruber/gt4py/pull/4). --------- Co-authored-by: Till Ehrengruber --- .../iterator/transforms/constant_folding.py | 231 +++++++++++++++--- .../transforms_tests/test_constant_folding.py | 189 ++++++++++---- 2 files changed, 345 insertions(+), 75 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 0dc324f94c..fdbfec99ca 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -6,57 +6,226 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from __future__ import annotations + +import dataclasses +import enum +import functools +import operator +from typing import Optional + +from gt4py import eve from gt4py.next.iterator import builtins, embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import fixed_point_transformation + +def _value_from_literal(literal: ir.Literal): + return getattr(embedded, str(literal.type))(literal.value) -class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + +class UndoCanonicalizeMinus(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ( "type", "domain", ) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + node = super().generic_visit(node, **kwargs) + # `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` + if cpm.is_call_to(node, "plus"): + a, b = node.args + if cpm.is_call_to(b, "neg"): + return im.minus(a, b.args[0]) + if isinstance(b, ir.Literal) and _value_from_literal(b) < 0: + return im.minus(a, -_value_from_literal(b)) + if cpm.is_call_to(a, "neg"): + return im.minus(b, a.args[0]) + if isinstance(a, ir.Literal) and _value_from_literal(a) < 0: + return im.minus(b, -_value_from_literal(a)) + return node + + +_COMMUTATIVE_OPS = ("plus", "multiplies", "minimum", "maximum") + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ConstantFolding( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + class Transformation(enum.Flag): + # `1 + a` -> `a + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `1 + f(...)` -> `f(...) + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `f(...) + (expr1 + expr2)` -> `(expr1 + expr2) + f(...)`, for `s[0] + (s[0] + 1)`, prerequisite for FOLD_MIN_MAX_PLUS + CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL = enum.auto() + + # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS + CANONICALIZE_MINUS = enum.auto() + + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX + CANONICALIZE_MIN_MAX = enum.auto() + + # `(a + 1) + 1` -> `a + (1 + 1)` + FOLD_FUNCALL_LITERAL = enum.auto() + + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + FOLD_MIN_MAX = enum.auto() + + # `maximum(a + 1), a)` -> `a + 1` + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` + FOLD_MIN_MAX_PLUS = enum.auto() + + # `a + 0` -> `a`, `a * 1` -> `a` + FOLD_NEUTRAL_OP = enum.auto() + + # `1 + 1` -> `2` + FOLD_ARITHMETIC_BUILTINS = enum.auto() + + # `minimum(a, a)` -> `a` + FOLD_MIN_MAX_LITERALS = enum.auto() + + # `if_(True, true_branch, false_branch)` -> `true_branch` + FOLD_IF = enum.auto() + + @classmethod + def all(self) -> ConstantFolding.Transformation: + return functools.reduce(operator.or_, self.__members__.values()) + + enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] + @classmethod def apply(cls, node: ir.Node) -> ir.Node: - return cls().visit(node) + node = cls().visit(node) + return UndoCanonicalizeMinus().visit(node) + + def transform_canonicalize_op_funcall_symref_literal( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + # `op(literal, symref|funcall)` -> `op(symref|funcall, literal)` + # `op1(funcall, op2(...))` -> `op1(op2(...), funcall)` for `s[0] + (s[0] + 1)` + if cpm.is_call_to(node, _COMMUTATIVE_OPS): + a, b = node.args + if (isinstance(a, ir.Literal) and not isinstance(b, ir.Literal)) or ( + not cpm.is_call_to(a, _COMMUTATIVE_OPS) and cpm.is_call_to(b, _COMMUTATIVE_OPS) + ): + return im.call(node.fun)(b, a) + return None - def visit_FunCall(self, node: ir.FunCall): - # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded - new_node = self.generic_visit(node) + def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `a - b` -> `a + (-b)` + if cpm.is_call_to(node, "minus"): + return im.plus(node.args[0], self.fp_transform(im.call("neg")(node.args[1]))) + return None + def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)` + if cpm.is_call_to(node, ("maximum", "minimum")): + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[1], op) and not cpm.is_call_to(node.args[0], op): + return im.call(op)(node.args[1], node.args[0]) + return None + + def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `(a + 1) + 1` -> `a + (1 + 1)` + if cpm.is_call_to(node, "plus"): + if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): + (expr, lit1), lit2 = node.args[0].args, node.args[1] + if isinstance(expr, (ir.SymRef, ir.FunCall)) and isinstance(lit1, ir.Literal): + return im.plus( + expr, + self.fp_transform(im.plus(lit1, lit2)), + ) + return None + + def transform_fold_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + if cpm.is_call_to(node, ("minimum", "maximum")): + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[0], op): + fun_call, arg1 = node.args + if arg1 in fun_call.args: # type: ignore[attr-defined] # assured by if above + return fun_call + return None + + def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] - and new_node.args[0] == new_node.args[1] - ): # `minimum(a, a)` -> `a` - return new_node.args[0] + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and cpm.is_call_to(node, ("minimum", "maximum")) + ): + arg0, arg1 = node.args + # `maximum(a + 1, a)` -> `a + 1` + if cpm.is_call_to(arg0, "plus"): + if arg0.args[0] == arg1: + return im.plus( + arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0)) + ) + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` + if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): + if arg0.args[0] == arg1.args[0]: + return im.plus( + arg0.args[0], + self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])), + ) + + return None + def transform_fold_neutral_op(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `a + 0` -> `a`, `a * 1` -> `a` if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id == "if_" - and isinstance(new_node.args[0], ir.Literal) - ): # `if_(True, true_branch, false_branch)` -> `true_branch` - if new_node.args[0].value == "True": - new_node = new_node.args[1] - else: - new_node = new_node.args[2] + cpm.is_call_to(node, "plus") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 0 + ) or ( + cpm.is_call_to(node, "multiplies") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 1 + ): + return node.args[0] + return None + @classmethod + def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `1 + 1` -> `2` if ( - isinstance(new_node, ir.FunCall) - and isinstance(new_node.fun, ir.SymRef) - and len(new_node.args) > 0 - and all(isinstance(arg, ir.Literal) for arg in new_node.args) - ): # `1 + 1` -> `2` + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and len(node.args) > 0 + and all(isinstance(arg, ir.Literal) for arg in node.args) + ): try: - if new_node.fun.id in builtins.ARITHMETIC_BUILTINS: - fun = getattr(embedded, str(new_node.fun.id)) + if node.fun.id in builtins.ARITHMETIC_BUILTINS: + fun = getattr(embedded, str(node.fun.id)) arg_values = [ - getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition - for arg in new_node.args + _value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition + for arg in node.args ] - new_node = im.literal_from_value(fun(*arg_values)) + return im.literal_from_value(fun(*arg_values)) except ValueError: pass # happens for inf and neginf + return None - return new_node + def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minimum(a, a)` -> `a` + if cpm.is_call_to(node, ("minimum", "maximum")): + if node.args[0] == node.args[1]: + return node.args[0] + return None + + def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `if_(True, true_branch, false_branch)` -> `true_branch` + if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal): + if node.args[0].value == "True": + return node.args[1] + else: + return node.args[2] + return None diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index cf325c2daa..1da2b8cec5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -9,54 +9,155 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding - -def test_constant_folding_boolean(): - testee = im.not_(im.literal_from_value(True)) - expected = im.literal_from_value(False) - - actual = ConstantFolding.apply(testee) - assert actual == expected +import pytest +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -def test_constant_folding_math_op(): - expected = im.literal_from_value(13) - testee = im.plus( - im.literal_from_value(4), - im.plus( - im.literal_from_value(7), im.minus(im.literal_from_value(7), im.literal_from_value(5)) +def test_cases(): + return ( + # expr, simplified expr + (im.plus(1, 1), 2), + (im.not_(True), False), + (im.plus(4, im.plus(7, im.minus(7, 5))), 13), + (im.if_(True, im.plus(im.ref("a"), 2), im.minus(9, 5)), im.plus("a", 2)), + (im.minimum("a", "a"), "a"), + (im.maximum(1, 2), 2), + # canonicalization + (im.plus("a", 1), im.plus("a", 1)), + (im.plus(1, "a"), im.plus("a", 1)), + # nested plus + (im.plus(im.plus("a", 1), 1), im.plus("a", 2)), + (im.plus(1, im.plus("a", 1)), im.plus("a", 2)), + # nested maximum + (im.maximum(im.maximum("a", 1), 1), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), 1), im.maximum("a", 1)), + (im.maximum("a", im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum("a", 1)), im.maximum("a", 1)), + (im.maximum(im.minimum("a", 1), "a"), im.maximum(im.minimum("a", 1), "a")), + # maximum & plus + (im.maximum(im.plus("a", 1), im.plus("a", 0)), im.plus("a", 1)), + ( + im.maximum(im.plus("a", 1), im.plus(im.plus("a", 1), 0)), + im.plus("a", 1), + ), + (im.maximum("a", im.plus("a", 1)), im.plus("a", 1)), + (im.maximum("a", im.plus("a", im.literal_from_value(-1))), im.ref("a")), + ( + im.plus("a", im.maximum(0, im.literal_from_value(-1))), + im.ref("a"), + ), + # plus & minus + (im.minus(im.plus("a", 1), im.plus(1, 1)), im.minus("a", 1)), + (im.plus(im.minus("a", 1), 2), im.plus("a", 1)), + (im.plus(im.minus(1, "a"), 1), im.minus(2, "a")), + # nested plus + (im.plus(im.plus("a", 1), im.plus(1, 1)), im.plus("a", 3)), + ( + im.plus(im.plus("a", im.literal_from_value(-1)), im.plus("a", 3)), + im.plus(im.minus("a", 1), im.plus("a", 3)), + ), + # maximum & minus + (im.maximum(im.minus("a", 1), "a"), im.ref("a")), + (im.maximum("a", im.minus("a", im.literal_from_value(-1))), im.plus("a", 1)), + ( + im.maximum(im.plus("a", im.literal_from_value(-1)), 1), + im.maximum(im.minus("a", 1), 1), + ), + # minimum & plus & minus + (im.minimum(im.plus("a", 1), "a"), im.ref("a")), + (im.minimum("a", im.plus("a", im.literal_from_value(-1))), im.minus("a", 1)), + (im.minimum(im.minus("a", 1), "a"), im.minus("a", 1)), + (im.minimum("a", im.minus("a", im.literal_from_value(-1))), im.ref("a")), + # nested maximum + (im.maximum("a", im.maximum("b", "a")), im.maximum("b", "a")), + # maximum & plus on complicated expr (tuple_get) + ( + im.maximum( + im.plus(im.tuple_get(1, "a"), 1), + im.maximum(im.tuple_get(1, "a"), im.plus(im.tuple_get(1, "a"), 1)), + ), + im.plus(im.tuple_get(1, "a"), 1), + ), + # nested maximum & plus + ( + im.maximum(im.maximum(im.plus(1, "a"), 1), im.plus(1, "a")), + im.maximum(im.plus("a", 1), 1), + ), + # sanity check that no strange things happen + # complex tests + ( + # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) + im.minus( + 1, + im.maximum( + im.maximum( + im.maximum(1, im.maximum(1, "a")), + im.maximum(im.maximum(1, "a"), "a"), + ), + im.plus( + 1, + im.plus( + im.minimum(im.literal_from_value(-1), 2), + im.maximum(im.literal_from_value(-1), im.minus(1, "a")), + ), + ), + ), + ), + # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) + im.minus( + 1, + im.maximum( + im.maximum("a", 1), + im.maximum(im.minus(1, "a"), im.literal_from_value(-1)), + ), + ), + ), + ( + # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 + im.minus( + im.plus( + im.maximum("a", im.plus(1, "a")), + im.plus( + im.maximum(1, im.maximum(1, "a")), + im.plus(im.minus("a", 1), im.plus(im.plus(1, im.plus("a", 1)), 1)), + ), + ), + 2, + ), + # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 + im.minus( + im.plus( + im.plus("a", 1), + im.plus( + im.maximum("a", 1), + im.plus(im.minus("a", 1), im.plus("a", 3)), + ), + ), + 2, + ), + ), + ( + # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) + im.plus( + im.minimum(im.minus(1, "a"), im.plus(1, "a")), + im.plus( + im.maximum(im.maximum(im.minus(1, "a"), im.plus(1, "a")), im.minus(1, "a")), + im.maximum(im.minus(1, "a"), im.minus(1, "a")), + ), + ), + # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) + im.plus( + im.minimum(im.minus(1, "a"), im.plus("a", 1)), + im.plus(im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a")), + ), ), ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_if(): - expected = im.plus("a", 2) - testee = im.if_( - im.literal_from_value(True), - im.plus(im.ref("a"), im.literal_from_value(2)), - im.minus(im.literal_from_value(9), im.literal_from_value(5)), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_minimum(): - testee = im.minimum("a", "a") - expected = im.ref("a") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_literal(): - testee = im.plus(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(3) - actual = ConstantFolding.apply(testee) - assert actual == expected -def test_constant_folding_literal_maximum(): - testee = im.maximum(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(2) +@pytest.mark.parametrize("test_case", test_cases()) +def test_constant_folding(test_case): + testee, expected = test_case actual = ConstantFolding.apply(testee) - assert actual == expected + assert actual == im.ensure_expr(expected) From e6b9398913672b758b3ac06dc8b0ba683341e8b5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 18 Mar 2025 15:19:22 +0100 Subject: [PATCH 177/178] refactor[cartesian]: Dace backend: expose control flow (#1894) ## Description This PR refactors the GT4Py/DaCe bridge to expose control flow elements (`if` statements and `while` loops) to DaCe. Previously, the whole contents of a vertical loop was put in one big Tasklet. With this PR, that Tasklet is broken apart in case control flow is found such that control flow is visible in the SDFG. This allows DaCe to better analyze code and will be crucial in future (within the current milestone) performance optimization work. The main ideas in this PR are the following 1. Introduce `oir.CodeBlock` to recursively break down `oir.HorizontalExecution`s into smaller pieces that are either code flow or evaluated in (smaller) Tasklets. 2. Introduce `dcir.Condition`and `dcir.WhileLoop` to represent if statements and while loops that are translated into SDFG states. We keep the current `dcir.MaskStmt` / `dcir.While` for if statements / while loops inside horizontal regions, which aren't yet exposed to DaCe (see https://github.com/GridTools/gt4py/issues/1900). 3. Add support for `if` statements and `while` loops in the state machine of `sdfg_builder.py` 4. We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). We thus create output connectors for all scalar writes in a Tasklet and input connectors for all reads (unless previously written in the same Tasklet). 5. Memlets can't be generated per horizontal execution anymore and need to be more fine grained. `TaskletAccessInfoCollector` does this work for us, duplicating some logic in `AccessInfoCollector`. A refactor task has been logged to fix/re-evaluate this later. This PR depends on the following (downstream) DaCe fixes - https://github.com/spcl/dace/pull/1954 - https://github.com/spcl/dace/pull/1955 which have been merged by now. Follow-up issues - unrelated changes have been moved to https://github.com/GridTools/gt4py/pull/1895 - https://github.com/GridTools/gt4py/issues/1896 - https://github.com/GridTools/gt4py/issues/1898 - https://github.com/GridTools/gt4py/issues/1900 Related issue: https://github.com/GEOS-ESM/NDSL/issues/53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Added new tests and increased coverage of horizontal regions with PRs #1807 and #1851. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. Docs are [in our knowledge base](https://geos-esm.github.io/SMT-Nebulae/technical/backend/dace-bridge/) for now. Will be ported. --------- Co-authored-by: Roman Cattaneo <> Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/backend/dace_backend.py | 4 +- src/gt4py/cartesian/gtc/dace/daceir.py | 140 +++- .../gtc/dace/expansion/daceir_builder.py | 664 ++++++++++++++---- .../cartesian/gtc/dace/expansion/expansion.py | 23 +- .../gtc/dace/expansion/sdfg_builder.py | 349 ++++++++- .../gtc/dace/expansion/tasklet_codegen.py | 85 ++- .../cartesian/gtc/dace/expansion/utils.py | 2 +- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 7 +- src/gt4py/cartesian/gtc/dace/prefix.py | 23 + src/gt4py/cartesian/gtc/dace/utils.py | 211 +++++- src/gt4py/cartesian/gtc/oir.py | 4 + .../multi_feature_tests/test_dace_parsing.py | 24 +- .../unit_tests/test_gtc/dace/__init__.py | 9 +- .../test_gtc/dace/test_daceir_builder.py | 109 +++ .../test_gtc/dace/test_sdfg_builder.py | 144 ++++ .../unit_tests/test_gtc/dace/test_utils.py | 44 ++ .../unit_tests/test_gtc/dace/utils.py | 54 ++ .../unit_tests/test_gtc/test_oir_to_dace.py | 159 +++++ 18 files changed, 1800 insertions(+), 255 deletions(-) create mode 100644 src/gt4py/cartesian/gtc/dace/prefix.py rename src/gt4py/cartesian/gtc/dace/constants.py => tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py (56%) create mode 100644 tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py create mode 100644 tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py create mode 100644 tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py create mode 100644 tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py create mode 100644 tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index a36a9824bd..5fef6d88ba 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -35,7 +35,6 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder from gt4py.cartesian.gtc.dace.transformations import ( - InlineThreadLocalTransients, NoEmptyEdgeTrivialMapElimination, nest_sequential_map_scopes, ) @@ -173,7 +172,8 @@ def _post_expand_transformations(sdfg: dace.SDFG): if node.schedule == dace.ScheduleType.CPU_Multicore and len(node.range) <= 1: node.schedule = dace.ScheduleType.Sequential - sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) + # To be re-evaluated with https://github.com/GridTools/gt4py/issues/1896 + # sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) # noqa: ERA001 sdfg.simplify(validate=False) nest_sequential_map_scopes(sdfg) for sd in sdfg.all_sdfgs_recursive(): diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 90c0649940..e07ae9e52b 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -17,6 +17,7 @@ from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.common import LocNode +from gt4py.cartesian.gtc.dace import prefix from gt4py.cartesian.gtc.dace.symbol_utils import ( get_axis_bound_dace_symbol, get_axis_bound_diff_str, @@ -525,10 +526,6 @@ class FieldAccessInfo(eve.Node): dynamic_access: bool = False variable_offset_axes: List[Axis] = eve.field(default_factory=list) - @property - def is_dynamic(self) -> bool: - return self.dynamic_access or len(self.variable_offset_axes) > 0 - def axes(self): yield from self.grid_subset.axes() @@ -713,7 +710,7 @@ def axes(self): @property def is_dynamic(self) -> bool: - return self.access_info.is_dynamic + return self.access_info.dynamic_access def with_set_access_info(self, access_info: FieldAccessInfo) -> FieldDecl: return FieldDecl( @@ -730,7 +727,8 @@ class Literal(common.Literal, Expr): class ScalarAccess(common.ScalarAccess, Expr): - pass + is_target: bool + original_name: Optional[str] = None class VariableKOffset(common.VariableKOffset[Expr]): @@ -744,7 +742,12 @@ def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Exp class IndexAccess(common.FieldAccess, Expr): - offset: Optional[Union[common.CartesianOffset, VariableKOffset]] + # ScalarAccess used for indirect addressing + offset: Optional[common.CartesianOffset | Literal | ScalarAccess | VariableKOffset] + is_target: bool + + explicit_indices: Optional[list[Literal | ScalarAccess | VariableKOffset]] = None + """Used to access as a full field with explicit indices""" class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt): @@ -842,33 +845,146 @@ class IterationNode(eve.Node): grid_subset: GridSubset +class Condition(eve.Node): + condition: Tasklet + true_states: list[ComputationState | Condition | WhileLoop] + + # Currently unused due to how `if` statements are parsed in `gtir_to_oir`, see + # https://github.com/GridTools/gt4py/issues/1898 + false_states: list[ComputationState | Condition | WhileLoop] = eve.field(default_factory=list) + + @datamodels.validator("condition") + def condition_has_boolean_expression( + self, attribute: datamodels.Attribute, tasklet: Tasklet + ) -> None: + assert isinstance(tasklet, Tasklet) + assert len(tasklet.stmts) == 1 + assert isinstance(tasklet.stmts[0], AssignStmt) + assert isinstance(tasklet.stmts[0].left, ScalarAccess) + if tasklet.stmts[0].left.original_name is None: + raise ValueError( + f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." + ) + assert isinstance(tasklet.stmts[0].right, Expr) + if tasklet.stmts[0].right.dtype != common.DataType.BOOL: + raise ValueError("Condition must be a boolean expression.") + + class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait): - decls: List[LocalScalarDecl] + label: str stmts: List[Stmt] grid_subset: GridSubset = GridSubset.single_gridpoint() + @datamodels.validator("stmts") + def non_empty_list(self, attribute: datamodels.Attribute, v: list[Stmt]) -> None: + if len(v) < 1: + raise ValueError("Tasklet must contain at least one statement.") + + @datamodels.validator("stmts") + def read_after_write(self, attribute: datamodels.Attribute, statements: list[Stmt]) -> None: + def _remove_prefix(name: eve.SymbolRef) -> str: + return name.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) + + class ReadAfterWriteChecker(eve.NodeVisitor): + def visit_IndexAccess(self, node: IndexAccess, writes: set[str]) -> None: + if node.is_target: + # Keep track of writes + writes.add(_remove_prefix(node.name)) + return + + # Check reads + if ( + node.name.startswith(prefix.TASKLET_OUT) + and _remove_prefix(node.name) not in writes + ): + raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.") + + if _remove_prefix(node.name) in writes and not node.name.startswith( + prefix.TASKLET_OUT + ): + raise ValueError( + f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." + ) + + def visit_ScalarAccess(self, node: ScalarAccess, writes: set[str]) -> None: + # Handle stencil parameters differently because they are always available + if not node.name.startswith(prefix.TASKLET_IN) and not node.name.startswith( + prefix.TASKLET_OUT + ): + return + + # Keep track of writes + if node.is_target: + writes.add(_remove_prefix(node.name)) + return + + # Make sure we don't read uninitialized memory + if ( + node.name.startswith(prefix.TASKLET_OUT) + and _remove_prefix(node.name) not in writes + ): + raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.") + + if _remove_prefix(node.name) in writes and not node.name.startswith( + prefix.TASKLET_OUT + ): + raise ValueError( + f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." + ) + + def visit_AssignStmt(self, node: AssignStmt, writes: Set[eve.SymbolRef]) -> None: + # Visiting order matters because `writes` must not contain the symbols from the left visit + self.visit(node.right, writes=writes) + self.visit(node.left, writes=writes) + + writes: set[str] = set() + checker = ReadAfterWriteChecker() + for statement in statements: + checker.visit(statement, writes=writes) + class DomainMap(ComputationNode, IterationNode): index_ranges: List[Range] schedule: MapSchedule - computations: List[Union[DomainMap, NestedSDFG, Tasklet]] + computations: List[Union[Tasklet, DomainMap, NestedSDFG]] class ComputationState(IterationNode): - computations: List[Union[DomainMap, Tasklet]] + computations: List[Union[Tasklet, DomainMap]] class DomainLoop(ComputationNode, IterationNode): axis: Axis index_range: Range - loop_states: List[Union[ComputationState, DomainLoop]] + loop_states: list[ComputationState | Condition | DomainLoop | WhileLoop] + + +class WhileLoop(eve.Node): + condition: Tasklet + body: list[ComputationState | Condition | WhileLoop] + + @datamodels.validator("condition") + def condition_has_boolean_expression( + self, attribute: datamodels.Attribute, tasklet: Tasklet + ) -> None: + assert isinstance(tasklet, Tasklet) + assert len(tasklet.stmts) == 1 + assert isinstance(tasklet.stmts[0], AssignStmt) + assert isinstance(tasklet.stmts[0].left, ScalarAccess) + if tasklet.stmts[0].left.original_name is None: + raise ValueError( + f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." + ) + assert isinstance(tasklet.stmts[0].right, Expr) + if tasklet.stmts[0].right.dtype != common.DataType.BOOL: + raise ValueError("Condition must be a boolean expression.") class NestedSDFG(ComputationNode, eve.SymbolTableTrait): label: eve.Coerced[eve.SymbolRef] field_decls: List[FieldDecl] symbol_decls: List[SymbolDecl] - states: List[Union[ComputationState, DomainLoop]] + states: list[ComputationState | Condition | DomainLoop | WhileLoop] # There are circular type references with string placeholders. These statements let datamodels resolve those. diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index e93a15debe..f05a89c5fa 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -10,7 +10,8 @@ import dataclasses from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Union, cast +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union, cast import dace import dace.data @@ -18,11 +19,11 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc import common, oir, utils as gtc_utils from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages from gt4py.cartesian.gtc.dace.utils import ( - compute_dcir_access_infos, + compute_tasklet_access_infos, flatten_list, get_tasklet_symbol, make_dace_subset, @@ -39,54 +40,96 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation -def _access_iter(node: oir.HorizontalExecution, get_outputs: bool): - if get_outputs: - iterator = filter( - lambda node: isinstance(node, oir.FieldAccess), - node.walk_values().if_isinstance(oir.AssignStmt).getattr("left"), +class AccessType(Enum): + READ = 0 + WRITE = 1 + + +def _field_access_iterator( + code_block: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType +): + if access_type == AccessType.WRITE: + return ( + code_block.walk_values() + .if_isinstance(oir.AssignStmt) + .getattr("left") + .if_isinstance(oir.FieldAccess) ) - else: - def _iterator(): - for n in node.walk_values(): - if isinstance(n, oir.AssignStmt): - yield from n.right.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(n, oir.While): - yield from n.cond.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(n, oir.MaskStmt): - yield from n.mask.walk_values().if_isinstance(oir.FieldAccess) + def read_access_iterator(): + for node in code_block.walk_values(): + if isinstance(node, oir.AssignStmt): + yield from node.right.walk_values().if_isinstance(oir.FieldAccess) + elif isinstance(node, oir.While): + yield from node.cond.walk_values().if_isinstance(oir.FieldAccess) + elif isinstance(node, oir.MaskStmt): + yield from node.mask.walk_values().if_isinstance(oir.FieldAccess) + + return read_access_iterator() - iterator = _iterator() + +def _mapped_access_iterator( + node: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType +): + iterator = _field_access_iterator(node, access_type) + write_access = access_type == AccessType.WRITE yield from ( eve.utils.xiter(iterator).map( lambda acc: ( acc.name, acc.offset, - get_tasklet_symbol(acc.name, acc.offset, is_target=get_outputs), + get_tasklet_symbol(acc.name, offset=acc.offset, is_target=write_access), ) ) ).unique(key=lambda x: x[2]) def _get_tasklet_inout_memlets( - node: oir.HorizontalExecution, + node: oir.CodeBlock | oir.MaskStmt | oir.While, + access_type: AccessType, *, - get_outputs: bool, global_ctx: DaCeIRBuilder.GlobalContext, - **kwargs: Any, -) -> List[dcir.Memlet]: - access_infos = compute_dcir_access_infos( + horizontal_extent, + k_interval, + grid_subset: dcir.GridSubset, + dcir_statements: list[dcir.Stmt], +) -> list[dcir.Memlet]: + access_infos = compute_tasklet_access_infos( node, - block_extents=global_ctx.library_node.get_extents, - oir_decls=global_ctx.library_node.declarations, - collect_read=not get_outputs, - collect_write=get_outputs, - **kwargs, + collect_read=access_type == AccessType.READ, + collect_write=access_type == AccessType.WRITE, + declarations=global_ctx.library_node.declarations, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=grid_subset, ) - memlets: List[dcir.Memlet] = [] - for name, offset, tasklet_symbol in _access_iter(node, get_outputs=get_outputs): + names = [ + access.name + for statement in dcir_statements + for access in statement.walk_values().if_isinstance(dcir.ScalarAccess, dcir.IndexAccess) + ] + + memlets: list[dcir.Memlet] = [] + for name, offset, tasklet_symbol in _mapped_access_iterator(node, access_type): + # Avoid adding extra inputs/outputs to the tasklet + if name not in access_infos: + continue + + # Find `tasklet_symbol` in dcir_statements because we can't know (from the oir statements) + # where the tasklet boundaries will be. Consider + # + # with computation(PARALLEL), interval(...): + # statement1 + # if condition: + # statement2 + # statement3 + # + # statements 1 and 3 will end up in the same CodeBlock but aren't in the same tasklet. + if tasklet_symbol not in names: + continue + access_info = access_infos[name] if not access_info.variable_offset_axes: offset_dict = offset.to_dict() @@ -100,8 +143,8 @@ def _get_tasklet_inout_memlets( field=name, connector=tasklet_symbol, access_info=access_info, - is_read=not get_outputs, - is_write=get_outputs, + is_read=access_type == AccessType.READ, + is_write=access_type == AccessType.WRITE, ) ) return memlets @@ -303,7 +346,13 @@ def visit_HorizontalRestriction( symbol_collector.add_symbol(axis.domain_symbol()) return dcir.HorizontalRestriction( - mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs) + mask=node.mask, + body=self.visit( + node.body, + symbol_collector=symbol_collector, + inside_horizontal_region=True, + **kwargs, + ), ) def visit_VariableKOffset( @@ -319,166 +368,351 @@ def visit_FieldAccess( node: oir.FieldAccess, *, is_target: bool, - targets: Set[eve.SymbolRef], - var_offset_fields: Set[eve.SymbolRef], - K_write_with_offset: Set[eve.SymbolRef], + targets: list[oir.FieldAccess | oir.ScalarAccess], + var_offset_fields: set[eve.SymbolRef], + K_write_with_offset: set[eve.SymbolRef], **kwargs: Any, - ) -> Union[dcir.IndexAccess, dcir.ScalarAccess]: + ) -> dcir.IndexAccess | dcir.ScalarAccess: """Generate the relevant accessor to match the memlet that was previously setup. - When a Field is written in K, we force the usage of the OUT memlet throughout the stencil - to make sure all side effects are being properly resolved. Frontend checks ensure that no - parallel code issues sips here. + Args: + is_target (bool): true if we write to this FieldAccess """ - res: Union[dcir.IndexAccess, dcir.ScalarAccess] + # Distinguish between writing to a variable and reading a previously written variable. + # In the latter case (read after write), we need to read from the "gtOUT__" symbol. + is_write = is_target + is_target = is_target or ( + # read after write (within a code block) + any( + isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset + for t in targets + ) + ) + name = get_tasklet_symbol(node.name, offset=node.offset, is_target=is_target) + + access_node: dcir.IndexAccess | dcir.ScalarAccess if node.name in var_offset_fields.union(K_write_with_offset): - # If write in K, we consider the variable to always be a target - is_target = is_target or node.name in targets or node.name in K_write_with_offset - name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) - res = dcir.IndexAccess( + access_node = dcir.IndexAccess( name=name, + is_target=is_target, offset=self.visit( node.offset, - is_target=is_target, + is_target=False, + targets=targets, + var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, + **kwargs, + ), + data_index=self.visit( + node.data_index, + is_target=False, targets=targets, var_offset_fields=var_offset_fields, K_write_with_offset=K_write_with_offset, **kwargs, ), - data_index=node.data_index, dtype=node.dtype, ) - else: - is_target = is_target or ( - node.name in targets and node.offset == common.CartesianOffset.zero() + elif node.data_index: + access_node = dcir.IndexAccess( + name=name, + offset=None, + is_target=is_target, + data_index=self.visit( + node.data_index, + is_target=False, + targets=targets, + var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, + **kwargs, + ), + dtype=node.dtype, ) - name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) - if node.data_index: - res = dcir.IndexAccess( - name=name, offset=None, data_index=node.data_index, dtype=node.dtype - ) - else: - res = dcir.ScalarAccess(name=name, dtype=node.dtype) - if is_target: - targets.add(node.name) - return res + else: + access_node = dcir.ScalarAccess(name=name, dtype=node.dtype, is_target=is_write) + + if is_write and not any( + isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset + for t in targets + ): + targets.append(node) + return access_node def visit_ScalarAccess( self, node: oir.ScalarAccess, *, + is_target: bool, + targets: list[oir.FieldAccess | oir.ScalarAccess], global_ctx: DaCeIRBuilder.GlobalContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, + **_: Any, ) -> dcir.ScalarAccess: if node.name in global_ctx.library_node.declarations: + # Handle stencil parameters differently because they are always available symbol_collector.add_symbol(node.name, dtype=node.dtype) - return dcir.ScalarAccess(name=node.name, dtype=node.dtype) - - def visit_AssignStmt(self, node: oir.AssignStmt, *, targets, **kwargs: Any) -> dcir.AssignStmt: - # the visiting order matters here, since targets must not contain the target symbols from the left visit - right = self.visit(node.right, is_target=False, targets=targets, **kwargs) - left = self.visit(node.left, is_target=True, targets=targets, **kwargs) - return dcir.AssignStmt(left=left, right=right) - - def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> dcir.MaskStmt: - return dcir.MaskStmt( - mask=self.visit(node.mask, is_target=False, **kwargs), - body=self.visit(node.body, **kwargs), + return dcir.ScalarAccess(name=node.name, dtype=node.dtype, is_target=is_target) + + # Distinguish between writing to a variable and reading a previously written variable. + # In the latter case (read after write), we need to read from the "gtOUT__" symbol. + is_write = is_target + is_target = is_target or ( + # read after write (within a code block) + any(isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets) ) - def visit_While(self, node: oir.While, **kwargs: Any) -> dcir.While: - return dcir.While( - cond=self.visit(node.cond, is_target=False, **kwargs), - body=self.visit(node.body, **kwargs), - ) + if is_write and not any( + isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets + ): + targets.append(node) - def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> dcir.Cast: - return dcir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) - - def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> dcir.NativeFuncCall: - return dcir.NativeFuncCall( - func=node.func, args=self.visit(node.args, **kwargs), dtype=node.dtype + # Rename local scalars inside tasklets such that we can pass them from one state + # to another (same as we do for index access). + tasklet_name = get_tasklet_symbol(node.name, is_target=is_target) + return dcir.ScalarAccess( + name=tasklet_name, original_name=node.name, dtype=node.dtype, is_target=is_write ) - def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> dcir.TernaryOp: - return dcir.TernaryOp( - cond=self.visit(node.cond, **kwargs), - true_expr=self.visit(node.true_expr, **kwargs), - false_expr=self.visit(node.false_expr, **kwargs), - dtype=node.dtype, - ) + def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> dcir.AssignStmt: + # Visiting order matters because targets must not contain the target symbols from the left visit + right = self.visit(node.right, is_target=False, **kwargs) + left = self.visit(node.left, is_target=True, **kwargs) + return dcir.AssignStmt(left=left, right=right, loc=node.loc) - def visit_HorizontalExecution( + def _condition_tasklet( self, - node: oir.HorizontalExecution, + node: oir.MaskStmt | oir.While, *, global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, k_interval, + iteration_ctx: DaCeIRBuilder.IterationContext, + targets: list[oir.FieldAccess | oir.ScalarAccess], **kwargs: Any, - ): - extent = global_ctx.library_node.get_extents(node) - decls = [self.visit(decl, **kwargs) for decl in node.declarations] - targets: Set[str] = set() - stmts = [ - self.visit( - stmt, + ) -> dcir.Tasklet: + condition_expression = node.mask if isinstance(node, oir.MaskStmt) else node.cond + prefix = "if" if isinstance(node, oir.MaskStmt) else "while" + tmp_name = f"{prefix}_expression_{id(node)}" + + # Reset the set of targets (used for detecting read after write inside a tasklet) + targets.clear() + + statement = dcir.AssignStmt( + right=self.visit( + condition_expression, + is_target=False, targets=targets, global_ctx=global_ctx, symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + iteration_ctx=iteration_ctx, **kwargs, - ) - for stmt in node.body - ] + ), + left=dcir.ScalarAccess( + name=get_tasklet_symbol(tmp_name, is_target=True), + original_name=tmp_name, + dtype=common.DataType.BOOL, + loc=node.loc, + is_target=True, + ), + loc=node.loc, + ) - stages_idx = next( - idx - for idx, item in enumerate(global_ctx.library_node.expansion_specification) - if isinstance(item, Stages) + read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.READ, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=[statement], ) - expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] - iteration_ctx = iteration_ctx.push_axes_extents( - {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} + tasklet = dcir.Tasklet( + label=f"eval_{prefix}_{id(node)}", + stmts=[statement], + read_memlets=read_memlets, + write_memlets=[], + ) + # See notes inside the function + self._fix_memlet_array_access( + tasklet=tasklet, + memlets=read_memlets, + global_context=global_ctx, + symbol_collector=symbol_collector, + targets=targets, + **kwargs, ) - iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) - assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() + return tasklet - read_memlets = _get_tasklet_inout_memlets( + def visit_MaskStmt( + self, + node: oir.MaskStmt, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + inside_horizontal_region: bool = False, + **kwargs: Any, + ) -> dcir.MaskStmt | dcir.Condition: + if inside_horizontal_region: + # inside horizontal regions, we use old-style mask statements that + # might translate to if statements inside the tasklet + return dcir.MaskStmt( + mask=self.visit( + node.mask, + is_target=False, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + body=self.visit( + node.body, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + ) + + tasklet = self._condition_tasklet( node, - get_outputs=False, global_ctx=global_ctx, - grid_subset=iteration_ctx.grid_subset, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, k_interval=k_interval, + iteration_ctx=iteration_ctx, + targets=targets, + **kwargs, ) + code_block = self.visit( + oir.CodeBlock(body=node.body, loc=node.loc, label=f"condition_{id(node)}"), + global_ctx=global_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + iteration_ctx=iteration_ctx, + targets=targets, + **kwargs, + ) + targets.clear() + return dcir.Condition(condition=tasklet, true_states=gtc_utils.listify(code_block)) + + def visit_While( + self, + node: oir.While, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + inside_horizontal_region: bool = False, + **kwargs: Any, + ) -> dcir.While | dcir.WhileLoop: + if inside_horizontal_region: + # inside horizontal regions, we use old-style while statements that + # might translate to while statements inside the tasklet + return dcir.While( + cond=self.visit( + node.cond, + is_target=False, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + body=self.visit( + node.body, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + ) - write_memlets = _get_tasklet_inout_memlets( + tasklet = self._condition_tasklet( node, - get_outputs=True, global_ctx=global_ctx, - grid_subset=iteration_ctx.grid_subset, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + targets=targets, + **kwargs, + ) + code_block = self.visit( + oir.CodeBlock(body=node.body, loc=node.loc, label=f"while_{id(node)}"), + global_ctx=global_ctx, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + horizontal_extent=horizontal_extent, k_interval=k_interval, + targets=targets, + **kwargs, ) + targets.clear() + return dcir.WhileLoop(condition=tasklet, body=code_block) - dcir_node = dcir.Tasklet( - decls=decls, stmts=stmts, read_memlets=read_memlets, write_memlets=write_memlets + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> dcir.Cast: + return dcir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) + + def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> dcir.NativeFuncCall: + return dcir.NativeFuncCall( + func=node.func, args=self.visit(node.args, **kwargs), dtype=node.dtype ) - for memlet in [*read_memlets, *write_memlets]: + def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> dcir.TernaryOp: + return dcir.TernaryOp( + cond=self.visit(node.cond, **kwargs), + true_expr=self.visit(node.true_expr, **kwargs), + false_expr=self.visit(node.false_expr, **kwargs), + dtype=node.dtype, + ) + + def _fix_memlet_array_access( + self, + *, + tasklet: dcir.Tasklet, + memlets: list[dcir.Memlet], + global_context: DaCeIRBuilder.GlobalContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + **kwargs: Any, + ) -> None: + for memlet in memlets: """ This loop handles the special case of a tasklet performing array access. The memlet should pass the full array shape (no tiling) and the tasklet expression for array access should use all explicit indexes. """ - array_ndims = len(global_ctx.arrays[memlet.field].shape) - field_decl = global_ctx.library_node.field_decls[memlet.field] + array_ndims = len(global_context.arrays[memlet.field].shape) + field_decl = global_context.library_node.field_decls[memlet.field] # calculate array subset on original memlet memlet_subset = make_dace_subset( - global_ctx.library_node.access_infos[memlet.field], + global_context.library_node.access_infos[memlet.field], memlet.access_info, field_decl.data_dims, ) @@ -490,22 +724,171 @@ def visit_HorizontalExecution( ] if len(memlet_data_index) < array_ndims: reshape_memlet = False - for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + for access_node in tasklet.walk_values().if_isinstance(dcir.IndexAccess): if access_node.data_index and access_node.name == memlet.connector: - access_node.data_index = memlet_data_index + access_node.data_index - assert len(access_node.data_index) == array_ndims + # Order matters! + # Resolve first the cartesian dimensions packed in memlet_data_index + access_node.explicit_indices = [] + for data_index in memlet_data_index: + access_node.explicit_indices.append( + self.visit( + data_index, + symbol_collector=symbol_collector, + global_ctx=global_context, + **kwargs, + ) + ) + # Separate between case where K is offset or absolute and + # where it's a regular offset (should be dealt with the above memlet_data_index) + if access_node.offset: + access_node.explicit_indices.append(access_node.offset) + # Add any remaining data dimensions indexing + for data_index in access_node.data_index: + access_node.explicit_indices.append( + self.visit( + data_index, + symbol_collector=symbol_collector, + global_ctx=global_context, + is_target=False, + **kwargs, + ) + ) + assert len(access_node.explicit_indices) == array_ndims reshape_memlet = True if reshape_memlet: # ensure that memlet symbols used for array indexing are defined in context for sym in memlet.access_info.grid_subset.free_symbols: symbol_collector.add_symbol(sym) # set full shape on memlet - memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + memlet.access_info = global_context.library_node.access_infos[memlet.field] + + def visit_CodeBlock( + self, + node: oir.CodeBlock, + *, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + **kwargs: Any, + ): + # Reset the set of targets (used for detecting read after write inside a tasklet) + targets.clear() + statements = [ + self.visit( + statement, + targets=targets, + global_ctx=global_ctx, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + k_interval=k_interval, + horizontal_extent=horizontal_extent, + **kwargs, + ) + for statement in node.body + ] + + # Gather all statements that aren't control flow (e.g. everything except Condition and WhileLoop), + # put them in a tasklet, and call "to_state" on it. + # Then, return a new list with types that are either ComputationState, Condition, or WhileLoop. + dace_nodes: list[dcir.ComputationState | dcir.Condition | dcir.WhileLoop] = [] + current_block: list[dcir.Stmt] = [] + for index, statement in enumerate(statements): + is_control_flow = isinstance(statement, (dcir.Condition, dcir.WhileLoop)) + if not is_control_flow: + current_block.append(statement) + + last_statement = index == len(statements) - 1 + if (is_control_flow or last_statement) and len(current_block) > 0: + read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.READ, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=current_block, + ) + write_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.WRITE, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=current_block, + ) + tasklet = dcir.Tasklet( + label=node.label, + stmts=current_block, + read_memlets=read_memlets, + write_memlets=write_memlets, + ) + # See notes inside the function + self._fix_memlet_array_access( + tasklet=tasklet, + memlets=[*read_memlets, *write_memlets], + global_context=global_ctx, + symbol_collector=symbol_collector, + targets=targets, + **kwargs, + ) + + dace_nodes.append(*self.to_state(tasklet, grid_subset=iteration_ctx.grid_subset)) + + # reset block scope + current_block = [] + + # append control flow statement after new tasklet (if applicable) + if is_control_flow: + dace_nodes.append(statement) + + return dace_nodes + + def visit_HorizontalExecution( + self, + node: oir.HorizontalExecution, + *, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + k_interval, + **kwargs: Any, + ): + extent = global_ctx.library_node.get_extents(node) + + stages_idx = next( + idx + for idx, item in enumerate(global_ctx.library_node.expansion_specification) + if isinstance(item, Stages) + ) + expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] + + iteration_ctx = iteration_ctx.push_axes_extents( + {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} + ) + iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) + assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() + + code_block = oir.CodeBlock(body=node.body, loc=node.loc, label=f"he_{id(node)}") + targets: list[oir.FieldAccess | oir.ScalarAccess] = [] + dcir_nodes = self.visit( + code_block, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=global_ctx.library_node.get_extents(node), + k_interval=k_interval, + targets=targets, + **kwargs, + ) for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() - dcir_node = self._process_iteration_item( - [dcir_node], + dcir_nodes = self._process_iteration_item( + dcir_nodes, item, global_ctx=global_ctx, iteration_ctx=iteration_ctx, @@ -514,7 +897,8 @@ def visit_HorizontalExecution( ) # pop stages context (pushed with push_grid_subset) iteration_ctx.pop() - return dcir_node + + return dcir_nodes def visit_VerticalLoopSection( self, @@ -577,7 +961,10 @@ def to_dataflow( nodes = flatten_list(nodes) if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return nodes - if not all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if not all( + isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) + for n in nodes + ): raise ValueError("Can't mix dataflow and state nodes on same level.") read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes) @@ -594,6 +981,7 @@ def to_dataflow( write_memlets = [ memlet.remove_read() for memlet in field_memlets if memlet.field in write_fields ] + return [ dcir.NestedSDFG( label=global_ctx.library_node.label, @@ -609,9 +997,12 @@ def to_dataflow( def to_state(self, nodes, *, grid_subset: dcir.GridSubset): nodes = flatten_list(nodes) - if all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if all( + isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) + for n in nodes + ): return nodes - if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): + if all(isinstance(n, (dcir.DomainMap, dcir.NestedSDFG, dcir.Tasklet)) for n in nodes): return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)] raise ValueError("Can't mix dataflow and state nodes on same level.") @@ -864,6 +1255,7 @@ def visit_VerticalLoop( read_fields = set(memlet.field for memlet in read_memlets) write_fields = set(memlet.field for memlet in write_memlets) + return dcir.NestedSDFG( label=global_ctx.library_node.label, states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset), diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 20d7743661..06ef69dcf4 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -17,8 +17,7 @@ import dace.subsets import sympy -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder @@ -78,11 +77,11 @@ def _fix_context( """ # change connector names for in_edge in parent_state.in_edges(node): - assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN) - in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN) + assert in_edge.dst_conn.startswith(prefix.CONNECTOR_IN) + in_edge.dst_conn = in_edge.dst_conn.removeprefix(prefix.CONNECTOR_IN) for out_edge in parent_state.out_edges(node): - assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT) - out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT) + assert out_edge.src_conn.startswith(prefix.CONNECTOR_OUT) + out_edge.src_conn = out_edge.src_conn.removeprefix(prefix.CONNECTOR_OUT) # union input and output subsets subsets = {} @@ -120,17 +119,25 @@ def _fix_context( if key in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[key] + for edge in parent_state.in_edges(node): + if edge.dst_conn not in nsdfg.in_connectors: + # Drop connection if connector is not found in the expansion of the library node + parent_state.remove_edge(edge) + if parent_state.in_degree(edge.src) + parent_state.out_degree(edge.src) == 0: + # Remove node if it is now isolated + parent_state.remove_node(edge.src) + @staticmethod def _get_parent_arrays( node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG ) -> Dict[str, dace.data.Data]: parent_arrays: Dict[str, dace.data.Data] = {} for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): - parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[ + parent_arrays[edge.dst_conn.removeprefix(prefix.CONNECTOR_IN)] = parent_sdfg.arrays[ edge.data.data ] for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None): - parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[ + parent_arrays[edge.src_conn.removeprefix(prefix.CONNECTOR_OUT)] = parent_sdfg.arrays[ edge.data.data ] return parent_arrays diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 3aeda7a484..c199891c13 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -18,12 +18,16 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset +def node_name_from_connector(connector: str) -> str: + return connector.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) + + class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait): @dataclass class NodeContext: @@ -36,18 +40,17 @@ class SDFGContext: state: dace.SDFGState state_stack: List[dace.SDFGState] = dataclasses.field(default_factory=list) - def add_state(self): - new_state = self.sdfg.add_state() + def add_state(self, label: Optional[str] = None) -> None: + new_state = self.sdfg.add_state(label=label) for edge in self.sdfg.out_edges(self.state): self.sdfg.remove_edge(edge) self.sdfg.add_edge(new_state, edge.dst, edge.data) self.sdfg.add_edge(self.state, new_state, dace.InterstateEdge()) self.state = new_state - return self - def add_loop(self, index_range: dcir.Range): - loop_state = self.sdfg.add_state() - after_state = self.sdfg.add_state() + def add_loop(self, index_range: dcir.Range) -> None: + loop_state = self.sdfg.add_state("loop_state") + after_state = self.sdfg.add_state("loop_after") for edge in self.sdfg.out_edges(self.state): self.sdfg.remove_edge(edge) self.sdfg.add_edge(after_state, edge.dst, edge.data) @@ -75,9 +78,126 @@ def add_loop(self, index_range: dcir.Range): self.state_stack.append(after_state) self.state = loop_state - return self - def pop_loop(self): + def pop_loop(self) -> None: + self._pop_last("loop_after") + + def add_condition(self, node: dcir.Condition) -> None: + """Inserts a condition after the current self.state. + + The condition consists of an initial state connected to a guard state, which branches + to a true_state and a false_state based on the given condition. Both states then merge + into a merge_state. + + self.state is set to init_state and the other states are pushed on the stack to be + popped with `pop_condition_*()` methods. + """ + # Data model validators enforce this to exist + assert isinstance(node.condition.stmts[0], dcir.AssignStmt) + assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) + condition_name = node.condition.stmts[0].left.original_name + + merge_state = self.sdfg.add_state("condition_after") + for edge in self.sdfg.out_edges(self.state): + self.sdfg.remove_edge(edge) + self.sdfg.add_edge(merge_state, edge.dst, edge.data) + + # Evaluate node condition + init_state = self.sdfg.add_state("condition_init") + self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) + + # Promote condition (from init_state) to symbol + condition_state = self.sdfg.add_state("condition_guard") + self.sdfg.add_edge( + init_state, + condition_state, + dace.InterstateEdge(assignments=dict(if_condition=condition_name)), + ) + + true_state = self.sdfg.add_state("condition_true") + self.sdfg.add_edge( + condition_state, true_state, dace.InterstateEdge(condition="if_condition") + ) + self.sdfg.add_edge(true_state, merge_state, dace.InterstateEdge()) + + false_state = self.sdfg.add_state("condition_false") + self.sdfg.add_edge( + condition_state, false_state, dace.InterstateEdge(condition="not if_condition") + ) + self.sdfg.add_edge(false_state, merge_state, dace.InterstateEdge()) + + self.state_stack.append(merge_state) + self.state_stack.append(false_state) + self.state_stack.append(true_state) + self.state_stack.append(condition_state) + self.state = init_state + + def pop_condition_guard(self) -> None: + self._pop_last("condition_guard") + + def pop_condition_true(self) -> None: + self._pop_last("condition_true") + + def pop_condition_false(self) -> None: + self._pop_last("condition_false") + + def pop_condition_after(self) -> None: + self._pop_last("condition_after") + + def add_while(self, node: dcir.WhileLoop) -> None: + """Inserts a while loop after the current state.""" + # Data model validators enforce this to exist + assert isinstance(node.condition.stmts[0], dcir.AssignStmt) + assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) + condition_name = node.condition.stmts[0].left.original_name + + after_state = self.sdfg.add_state("while_after") + for edge in self.sdfg.out_edges(self.state): + self.sdfg.remove_edge(edge) + self.sdfg.add_edge(after_state, edge.dst, edge.data) + + # Evaluate loop condition + init_state = self.sdfg.add_state("while_init") + self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) + + # Promote condition (from init_state) to symbol + guard_state = self.sdfg.add_state("while_guard") + self.sdfg.add_edge( + init_state, + guard_state, + dace.InterstateEdge(assignments=dict(loop_condition=condition_name)), + ) + + loop_state = self.sdfg.add_state("while_loop") + self.sdfg.add_edge( + guard_state, loop_state, dace.InterstateEdge(condition="loop_condition") + ) + # Loop back to init_state to re-evaluate the loop condition + self.sdfg.add_edge(loop_state, init_state, dace.InterstateEdge()) + + # Exit the loop + self.sdfg.add_edge( + guard_state, after_state, dace.InterstateEdge(condition="not loop_condition") + ) + + self.state_stack.append(after_state) + self.state_stack.append(loop_state) + self.state_stack.append(guard_state) + self.state = init_state + + def pop_while_guard(self) -> None: + self._pop_last("while_guard") + + def pop_while_loop(self) -> None: + self._pop_last("while_loop") + + def pop_while_after(self) -> None: + self._pop_last("while_after") + + def _pop_last(self, node_label: str | None = None) -> None: + if node_label is not None: + assert self.state_stack[-1].label.startswith(node_label) + self.state = self.state_stack[-1] del self.state_stack[-1] @@ -131,6 +251,91 @@ def _add_empty_edges( exit_node, None, *node_ctx.output_node_and_conns[None], dace.Memlet() ) + def visit_WhileLoop( + self, + node: dcir.WhileLoop, + *, + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, + symtable: ChainMap[eve.SymbolRef, dcir.Decl], + **kwargs: Any, + ) -> None: + sdfg_ctx.add_while(node) + assert sdfg_ctx.state.label.startswith("while_init") + + read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + write_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + for memlet in node.condition.read_memlets: + if memlet.field not in read_acc_and_conn: + read_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + for memlet in node.condition.write_memlets: + if memlet.field not in write_acc_and_conn: + write_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( + input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn + ) + self.visit( + node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, symtable=symtable, **kwargs + ) + + sdfg_ctx.pop_while_guard() + sdfg_ctx.pop_while_loop() + + for state in node.body: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_while_after() + + def visit_Condition( + self, + node: dcir.Condition, + *, + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, + symtable: ChainMap[eve.SymbolRef, dcir.Decl], + **kwargs: Any, + ) -> None: + sdfg_ctx.add_condition(node) + assert sdfg_ctx.state.label.startswith("condition_init") + + read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + write_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + for memlet in node.condition.read_memlets: + if memlet.field not in read_acc_and_conn: + read_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + for memlet in node.condition.write_memlets: + if memlet.field not in write_acc_and_conn: + write_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( + input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn + ) + self.visit( + node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, symtable=symtable, **kwargs + ) + + sdfg_ctx.pop_condition_guard() + sdfg_ctx.pop_condition_true() + for state in node.true_states: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_condition_false() + for state in node.false_states: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_condition_after() + def visit_Tasklet( self, node: dcir.Tasklet, @@ -145,16 +350,100 @@ def visit_Tasklet( read_memlets=node.read_memlets, write_memlets=node.write_memlets, symtable=symtable, + sdfg=sdfg_ctx.sdfg, ) + # We are breaking up vertical loops inside stencils in multiple Tasklets + # It might thus happen that we write a "local" scalar in one Tasklet and + # read it in another Tasklet (downstream). + # We thus create output connectors for all writes to scalar variables + # inside Tasklets. And input connectors for all scalar reads unless + # previously written in the same Tasklet. DaCe's simplify pipeline will get + # rid of any dead dataflow introduced with this general approach. + scalar_inputs: set[str] = set() + scalar_outputs: set[str] = set() + + # Gather scalar writes in this Tasklet + for access_node in node.walk_values().if_isinstance(dcir.AssignStmt): + target_name = access_node.left.name + + field_access = ( + len( + set( + memlet.connector + for memlet in [*node.write_memlets] + if memlet.connector == target_name + ) + ) + > 0 + ) + if field_access or target_name in scalar_outputs: + continue + + assert isinstance(access_node.left, dcir.ScalarAccess) + assert ( + access_node.left.original_name is not None + ), "Original name not found for '{access_nodes.left.name}'. DaCeIR error." + + original_name = access_node.left.original_name + scalar_outputs.add(target_name) + if original_name not in sdfg_ctx.sdfg.arrays: + sdfg_ctx.sdfg.add_scalar( + original_name, + dtype=data_type_to_dace_typeclass(access_node.left.dtype), + transient=True, + ) + + # Gather scalar reads in this Tasklet + for access_node in node.walk_values().if_isinstance(dcir.ScalarAccess): + read_name = access_node.name + field_access = ( + len( + set( + memlet.connector + for memlet in [*node.read_memlets, *node.write_memlets] + if memlet.connector == read_name + ) + ) + > 0 + ) + defined_symbol = any(read_name in symbol_map for symbol_map in symtable.maps) + + if ( + not field_access + and not defined_symbol + and not access_node.is_target + and read_name.startswith(prefix.TASKLET_IN) + and read_name not in scalar_inputs + ): + scalar_inputs.add(read_name) + + inputs = set(memlet.connector for memlet in node.read_memlets).union(scalar_inputs) + outputs = set(memlet.connector for memlet in node.write_memlets).union(scalar_outputs) + tasklet = sdfg_ctx.state.add_tasklet( - name=f"{sdfg_ctx.sdfg.label}_Tasklet", + name=node.label, code=code, - inputs=set(memlet.connector for memlet in node.read_memlets), - outputs=set(memlet.connector for memlet in node.write_memlets), + inputs=inputs, + outputs=outputs, debuginfo=get_dace_debuginfo(node), ) + # Add memlets for scalars access (read/write) + for connector in scalar_outputs: + original_name = node_name_from_connector(connector) + access_node = sdfg_ctx.state.add_write(original_name) + sdfg_ctx.state.add_memlet_path( + tasklet, access_node, src_conn=connector, memlet=dace.Memlet(data=original_name) + ) + for connector in scalar_inputs: + original_name = node_name_from_connector(connector) + access_node = sdfg_ctx.state.add_read(original_name) + sdfg_ctx.state.add_memlet_path( + access_node, tasklet, dst_conn=connector, memlet=dace.Memlet(data=original_name) + ) + + # Add memlets for field access (read/write) self.visit( node.read_memlets, scope_node=tasklet, @@ -171,9 +460,6 @@ def visit_Tasklet( symtable=symtable, **kwargs, ) - StencilComputationSDFGBuilder._add_empty_edges( - tasklet, tasklet, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx - ) def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]: start, end = node.interval.to_dace_symbolic() @@ -204,13 +490,13 @@ def visit_DomainMap( input_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} output_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} for field in set(memlet.field for memlet in scope_node.read_memlets): - map_entry.add_in_connector("IN_" + field) - map_entry.add_out_connector("OUT_" + field) - input_node_and_conns[field] = (map_entry, "OUT_" + field) + map_entry.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") + map_entry.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") + input_node_and_conns[field] = (map_entry, f"{prefix.PASSTHROUGH_OUT}{field}") for field in set(memlet.field for memlet in scope_node.write_memlets): - map_exit.add_in_connector("IN_" + field) - map_exit.add_out_connector("OUT_" + field) - output_node_and_conns[field] = (map_exit, "IN_" + field) + map_exit.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") + map_exit.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") + output_node_and_conns[field] = (map_exit, f"{prefix.PASSTHROUGH_IN}{field}") if not input_node_and_conns: input_node_and_conns[None] = (map_entry, None) if not output_node_and_conns: @@ -226,7 +512,7 @@ def visit_DomainMap( scope_node=map_entry, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, - connector_prefix="IN_", + connector_prefix=prefix.PASSTHROUGH_IN, **kwargs, ) self.visit( @@ -234,7 +520,7 @@ def visit_DomainMap( scope_node=map_exit, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, - connector_prefix="OUT_", + connector_prefix=prefix.PASSTHROUGH_OUT, **kwargs, ) StencilComputationSDFGBuilder._add_empty_edges( @@ -248,7 +534,7 @@ def visit_DomainLoop( sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, **kwargs: Any, ) -> None: - sdfg_ctx = sdfg_ctx.add_loop(node.index_range) + sdfg_ctx.add_loop(node.index_range) self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs) sdfg_ctx.pop_loop() @@ -260,6 +546,13 @@ def visit_ComputationState( **kwargs: Any, ) -> None: sdfg_ctx.add_state() + + # node_ctx is used to keep track of memlets per ComputationState. Conditions and WhileLoops + # will (recursively) introduce more than one compute state per vertical loop. We thus drop + # any node_ctx that is potentially passed down and instead create a new one for each + # ComputationState that we encounter. + kwargs.pop("node_ctx", None) + read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} write_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} for computation in node.computations: @@ -333,7 +626,13 @@ def visit_NestedSDFG( symbol_mapping = {decl.name: decl.to_dace_symbol() for decl in node.symbol_decls} for computation_state in node.states: - self.visit(computation_state, sdfg_ctx=inner_sdfg_ctx, symtable=symtable, **kwargs) + self.visit( + computation_state, + sdfg_ctx=inner_sdfg_ctx, + node_ctx=node_ctx, + symtable=symtable, + **kwargs, + ) if sdfg_ctx is not None and node_ctx is not None: nsdfg = sdfg_ctx.state.add_nested_sdfg( diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 2948b9d76d..29104b2a6e 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -30,23 +30,17 @@ def _visit_offset( node: Union[dcir.VariableKOffset, common.CartesianOffset], *, access_info: dcir.FieldAccessInfo, - decl: dcir.FieldDecl, **kwargs: Any, ) -> str: int_sizes: List[Optional[int]] = [] for i, axis in enumerate(access_info.axes()): memlet_shape = access_info.shape - if ( - str(memlet_shape[i]).isnumeric() - and axis not in decl.access_info.variable_offset_axes - ): + if str(memlet_shape[i]).isnumeric() and axis not in access_info.variable_offset_axes: int_sizes.append(int(memlet_shape[i])) else: int_sizes.append(None) sym_offsets = [ - dace.symbolic.pystr_to_symbolic( - self.visit(off, access_info=access_info, decl=decl, **kwargs) - ) + dace.symbolic.pystr_to_symbolic(self.visit(off, access_info=access_info, **kwargs)) for off in (node.to_dict()["i"], node.to_dict()["j"], node.k) ] for axis in access_info.variable_offset_axes: @@ -62,10 +56,26 @@ def _visit_offset( res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) return str(res) - def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs: Any) -> str: + def _explicit_indexing( + self, node: common.CartesianOffset | dcir.VariableKOffset, **kwargs: Any + ) -> str: + """If called from the explicit pass we need to be add manually the relative indexing""" + return f"__k+{self.visit(node.k, **kwargs)}" + + def visit_CartesianOffset( + self, node: common.CartesianOffset, explicit=False, **kwargs: Any + ) -> str: + if explicit: + return self._explicit_indexing(node, **kwargs) + return self._visit_offset(node, **kwargs) - def visit_VariableKOffset(self, node: dcir.VariableKOffset, **kwargs: Any) -> str: + def visit_VariableKOffset( + self, node: dcir.VariableKOffset, explicit=False, **kwargs: Any + ) -> str: + if explicit: + return self._explicit_indexing(node, **kwargs) + return self._visit_offset(node, **kwargs) def visit_IndexAccess( @@ -73,6 +83,7 @@ def visit_IndexAccess( node: dcir.IndexAccess, *, is_target: bool, + sdfg: dace.SDFG, symtable: ChainMap[eve.SymbolRef, dcir.Decl], **kwargs: Any, ) -> str: @@ -90,22 +101,46 @@ def visit_IndexAccess( "Memlet connector and tasklet variable mismatch, DaCe IR error." ) from None - index_strs = [] - if node.offset is not None: - index_strs.append( - self.visit( - node.offset, - decl=symtable[memlet.field], - access_info=memlet.access_info, - symtable=symtable, - in_idx=True, - **kwargs, + index_strs: list[str] = [] + if node.explicit_indices: + # Full array access with every dimensions accessed in full. + # Everything was packed in `explicit_indices` in `DaCeIRBuilder._fix_memlet_array_access()` + # along the `reshape_memlet=True` code path. + assert len(node.explicit_indices) == len(sdfg.arrays[memlet.field].shape) + for idx in node.explicit_indices: + index_strs.append( + self.visit( + idx, + symtable=symtable, + in_idx=True, + explicit=True, + **kwargs, + ) ) + else: + # Grid-point access, I & J are unitary, K can be offsetted with variable + # Resolve K offset (also resolves I & J) + if node.offset is not None: + index_strs.append( + self.visit( + node.offset, + access_info=memlet.access_info, + symtable=symtable, + in_idx=True, + **kwargs, + ) + ) + # Add any data dimensions + index_strs.extend( + self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index ) - index_strs.extend( - self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index + # Filter empty strings + non_empty_indices = list(filter(None, index_strs)) + return ( + f"{node.name}[{','.join(non_empty_indices)}]" + if len(non_empty_indices) > 0 + else node.name ) - return f"{node.name}[{','.join(index_strs)}]" def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str: # Visiting order matters because targets must not contain the target symbols from the left visit @@ -207,10 +242,8 @@ def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: Param = as_fmt("{name}") - LocalScalarDecl = as_fmt("{name}: {dtype}") - def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str: - return "\n".join(self.visit(node.decls, **kwargs) + self.visit(node.stmts, **kwargs)) + return "\n".join(self.visit(node.stmts, **kwargs)) def _visit_conditional( self, diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 7a29ec99a6..637b348a03 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -40,8 +40,8 @@ def visit_Tasklet(self, node: dcir.Tasklet): else: res_body.append(newstmt) return dcir.Tasklet( + label=f"he_remover_{id(node)}", stmts=res_body, - decls=node.decls, read_memlets=node.read_memlets, write_memlets=node.write_memlets, ) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index bd06da7d8f..d80e14296b 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -17,8 +17,7 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass from gt4py.cartesian.gtc.dace.utils import ( @@ -129,7 +128,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG for field in access_collection.read_fields(): access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) - connector_name = f"{CONNECTOR_PREFIX_IN}{field}" + connector_name = f"{prefix.CONNECTOR_IN}{field}" library_node.add_in_connector(connector_name) subset = ctx.make_input_dace_subset(node, field) state.add_edge( @@ -138,7 +137,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG for field in access_collection.write_fields(): access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) - connector_name = f"{CONNECTOR_PREFIX_OUT}{field}" + connector_name = f"{prefix.CONNECTOR_OUT}{field}" library_node.add_out_connector(connector_name) subset = ctx.make_output_dace_subset(node, field) state.add_edge( diff --git a/src/gt4py/cartesian/gtc/dace/prefix.py b/src/gt4py/cartesian/gtc/dace/prefix.py new file mode 100644 index 0000000000..1da9eb95f3 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/prefix.py @@ -0,0 +1,23 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from typing import Final + + +# DaCe passthrough prefixes +PASSTHROUGH_IN: Final[str] = "IN_" +PASSTHROUGH_OUT: Final[str] = "OUT_" + +# StencilComputation in/out connector prefixes +CONNECTOR_IN: Final[str] = "__in_" +CONNECTOR_OUT: Final[str] = "__out_" + +# Tasklet in/out connector prefixes +TASKLET_IN: Final[str] = "gtIN__" +TASKLET_OUT: Final[str] = "gtOUT__" diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 4e8a0f0c7b..4ef48ebcd9 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -19,7 +19,7 @@ from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents @@ -67,22 +67,25 @@ def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, def get_tasklet_symbol( - name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool + name: str, + *, + offset: Optional[CartesianOffset | VariableKOffset] = None, + is_target: bool, ): - if is_target: - return f"gtOUT__{name}" - - acc_name = f"gtIN__{name}" - if offset is not None: - offset_strs = [] - for axis in dcir.Axis.dims_3d(): - off = offset.to_dict()[axis.lower()] - if off is not None and off != 0: - offset_strs.append(axis.lower() + ("m" if off < 0 else "p") + f"{abs(off):d}") - suffix = "_".join(offset_strs) - if suffix != "": - acc_name += suffix - return acc_name + access_name = f"{prefix.TASKLET_OUT}{name}" if is_target else f"{prefix.TASKLET_IN}{name}" + if offset is None: + return access_name + + # add (per axis) offset markers, e.g. gtIN__A_km1 for A[0, 0, -1] + offset_strings = [] + for axis in dcir.Axis.dims_3d(): + axis_offset = offset.to_dict()[axis.lower()] + if axis_offset is not None and axis_offset != 0: + offset_strings.append( + axis.lower() + ("m" if axis_offset < 0 else "p") + f"{abs(axis_offset):d}" + ) + + return access_name + "_".join(offset_strings) def axes_list_from_flags(flags): @@ -196,7 +199,8 @@ def visit_MaskStmt(self, node: oir.MaskStmt, *, is_conditional=False, **kwargs): self.visit(node.body, is_conditional=True, **kwargs) def visit_While(self, node: oir.While, *, is_conditional=False, **kwargs): - self.generic_visit(node, is_conditional=True, **kwargs) + self.visit(node.cond, is_conditional=is_conditional, **kwargs) + self.visit(node.body, is_conditional=True, **kwargs) @staticmethod def _global_grid_subset( @@ -242,12 +246,8 @@ def _make_access_info( is_write, ) -> dcir.FieldAccessInfo: # Check we have expression offsets in K - # OR write offsets in K offset = [offset_node.to_dict()[k] for k in "ijk"] - if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write): - variable_offset_axes = [dcir.Axis.K] - else: - variable_offset_axes = [] + variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, oir.VariableKOffset) else [] global_subset = self._global_grid_subset(region, he_grid, offset) intervals = {} @@ -266,7 +266,6 @@ def _make_access_info( return dcir.FieldAccessInfo( grid_subset=grid_subset, global_grid_subset=global_subset, - dynamic_access=len(variable_offset_axes) > 0 or is_conditional or region is not None, variable_offset_axes=variable_offset_axes, ) @@ -347,6 +346,170 @@ def compute_dcir_access_infos( return ctx.access_infos +class TaskletAccessInfoCollector(eve.NodeVisitor): + @dataclass + class Context: + axes: dict[str, list[dcir.Axis]] + access_infos: dict[str, dcir.FieldAccessInfo] = field(default_factory=dict) + + def __init__( + self, collect_read: bool, collect_write: bool, *, horizontal_extent, k_interval, grid_subset + ): + self.collect_read: bool = collect_read + self.collect_write: bool = collect_write + + self.ij_grid = dcir.GridSubset.from_gt4py_extent(horizontal_extent) + self.he_grid = self.ij_grid.set_interval(dcir.Axis.K, k_interval) + self.grid_subset = grid_subset + + def visit_CodeBlock(self, _node: oir.CodeBlock, **_kwargs): + raise RuntimeError("We shouldn't reach code blocks anymore") + + def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs): + self.visit(node.right, is_write=False, **kwargs) + self.visit(node.left, is_write=True, **kwargs) + + def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs): + self.visit(node.mask, is_write=False, **kwargs) + self.visit(node.body, **kwargs) + + def visit_While(self, node: oir.While, **kwargs): + self.visit(node.cond, is_write=False, **kwargs) + self.visit(node.body, **kwargs) + + def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs): + self.visit(node.mask, is_write=False, **kwargs) + self.visit(node.body, region=node.mask, **kwargs) + + def _global_grid_subset( + self, + region: Optional[common.HorizontalMask], + offset: list[Optional[int]], + ): + res: dict[dcir.Axis, dcir.DomainInterval | dcir.IndexWithExtent | dcir.TileInterval] = {} + if region is not None: + for axis, oir_interval in zip(dcir.Axis.dims_horizontal(), region.intervals): + he_grid_interval = self.he_grid.intervals[axis] + assert isinstance(he_grid_interval, dcir.DomainInterval) + start = ( + oir_interval.start if oir_interval.start is not None else he_grid_interval.start + ) + end = oir_interval.end if oir_interval.end is not None else he_grid_interval.end + dcir_interval = dcir.DomainInterval( + start=dcir.AxisBound.from_common(axis, start), + end=dcir.AxisBound.from_common(axis, end), + ) + res[axis] = dcir.DomainInterval.union(dcir_interval, res.get(axis, dcir_interval)) + if dcir.Axis.K in self.he_grid.intervals: + off = offset[dcir.Axis.K.to_idx()] or 0 + he_grid_k_interval = self.he_grid.intervals[dcir.Axis.K] + assert not isinstance(he_grid_k_interval, dcir.TileInterval) + res[dcir.Axis.K] = he_grid_k_interval.shifted(off) + for axis in dcir.Axis.dims_horizontal(): + iteration_interval = self.he_grid.intervals[axis] + mask_interval = res.get(axis, iteration_interval) + res[axis] = dcir.DomainInterval.intersection( + axis, iteration_interval, mask_interval + ).shifted(offset[axis.to_idx()]) + return dcir.GridSubset(intervals=res) + + def _make_access_info( + self, + offset_node: CartesianOffset | VariableKOffset, + axes, + region: Optional[common.HorizontalMask], + ) -> dcir.FieldAccessInfo: + # Check we have expression offsets in K + offset = [offset_node.to_dict()[k] for k in "ijk"] + variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, VariableKOffset) else [] + + global_subset = self._global_grid_subset(region, offset) + intervals = {} + for axis in axes: + extent = ( + (0, 0) + if axis in variable_offset_axes + else (offset[axis.to_idx()], offset[axis.to_idx()]) + ) + intervals[axis] = dcir.IndexWithExtent( + axis=axis, value=axis.iteration_symbol(), extent=extent + ) + + return dcir.FieldAccessInfo( + grid_subset=dcir.GridSubset(intervals=intervals), + global_grid_subset=global_subset, + # Field access inside horizontal regions might or might not happen + dynamic_access=region is not None, + variable_offset_axes=variable_offset_axes, + ) + + def visit_FieldAccess( + self, + node: oir.FieldAccess, + *, + is_write: bool, + region: Optional[common.HorizontalMask] = None, + ctx: TaskletAccessInfoCollector.Context, + **kwargs, + ): + self.visit(node.offset, ctx=ctx, is_write=False, region=region, **kwargs) + + if (is_write and not self.collect_write) or (not is_write and not self.collect_read): + return + + access_info = self._make_access_info( + node.offset, + axes=ctx.axes[node.name], + region=region, + ) + ctx.access_infos[node.name] = access_info.union( + ctx.access_infos.get(node.name, access_info) + ) + + +def compute_tasklet_access_infos( + node: oir.CodeBlock | oir.MaskStmt | oir.While, + *, + collect_read: bool = True, + collect_write: bool = True, + declarations: dict[str, oir.Decl], + horizontal_extent, + k_interval, + grid_subset, +): + """ + Compute access information needed to build Memlets for the Tasklet + associated with the given `node`. + """ + axes = { + name: axes_list_from_flags(declaration.dimensions) + for name, declaration in declarations.items() + if isinstance(declaration, oir.FieldDecl) + } + ctx = TaskletAccessInfoCollector.Context(axes=axes, access_infos=dict()) + collector = TaskletAccessInfoCollector( + collect_read=collect_read, + collect_write=collect_write, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=grid_subset, + ) + if isinstance(node, oir.CodeBlock): + collector.visit(node.body, ctx=ctx) + elif isinstance(node, oir.MaskStmt): + # node.mask is a simple expression. + # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` + collector.visit(node.mask, ctx=ctx, is_write=False) + elif isinstance(node, oir.While): + # node.cond is a simple expression. + # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` + collector.visit(node.cond, ctx=ctx, is_write=False) + else: + raise ValueError("Unexpected node type.") + + return ctx.access_infos + + def make_dace_subset( context_info: dcir.FieldAccessInfo, access_info: dcir.FieldAccessInfo, @@ -357,7 +520,7 @@ def make_dace_subset( for axis in access_info.axes(): if axis in access_info.variable_offset_axes: clamped_access_info = clamped_access_info.clamp_full_axis(axis) - if axis in clamped_context_info.variable_offset_axes: + if axis in context_info.variable_offset_axes: clamped_context_info = clamped_context_info.clamp_full_axis(axis) res_ranges = [] diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index 9f24db6e48..1ba36b5077 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -33,6 +33,10 @@ class Stmt(common.Stmt): pass +class CodeBlock(common.BlockStmt[Stmt], Stmt): + label: str + + class Literal(common.Literal, Expr): pass diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py index 16a1860d9d..faeca7b8dc 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py @@ -6,29 +6,31 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import pathlib -import re -import typing +import pytest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import dace +else: + dace = pytest.importorskip("dace") import hypothesis.strategies as hyp_st import numpy as np -import pytest +import pathlib +import re +import typing from gt4py import cartesian as gt4pyc, storage as gt_storage from gt4py.cartesian import gtscript from gt4py.cartesian.gtscript import PARALLEL, computation, interval from gt4py.cartesian.stencil_builder import StencilBuilder from gt4py.storage.cartesian import utils as storage_utils +from gt4py.cartesian.backend.dace_lazy_stencil import DaCeLazyStencil from cartesian_tests.utils import OriginWrapper - -dace = pytest.importorskip("dace") -from gt4py.cartesian.backend.dace_lazy_stencil import ( # noqa: E402 [import-shadowed-by-loop-var] 'importorskip' is needed - DaCeLazyStencil, -) - - +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. pytestmark = [pytest.mark.requires_dace, pytest.mark.usefixtures("dace_env")] diff --git a/src/gt4py/cartesian/gtc/dace/constants.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py similarity index 56% rename from src/gt4py/cartesian/gtc/dace/constants.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py index 5565f1c186..c1f188446b 100644 --- a/src/gt4py/cartesian/gtc/dace/constants.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py @@ -6,10 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest -from typing import Final - - -# StencilComputation in/out connector prefixes -CONNECTOR_PREFIX_IN: Final = "__in_" -CONNECTOR_PREFIX_OUT: Final = "__out_" +# Skip this entire folder when we collecting tests and "dace" is not installed as a dependency. +pytest.importorskip("dace") diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py new file mode 100644 index 0000000000..af23d7056a --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py @@ -0,0 +1,109 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py.cartesian.gtc.dace import daceir as dcir + +from cartesian_tests.unit_tests.test_gtc.dace import utils +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + BinaryOpFactory, + HorizontalExecutionFactory, + LiteralFactory, + LocalScalarFactory, + MaskStmtFactory, + ScalarAccessFactory, + StencilFactory, + WhileFactory, +) + + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_dcir_code_structure_condition() -> None: + """Tests the following code structure: + + ComputationState + Condition + true_states: [ComputationState] + false_states: [] + ComputationState + """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + MaskStmtFactory(), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + assert len(expansions) == 1, "expect one vertical loop to be expanded" + + nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) + assert isinstance(nested_SDFG.states[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[1], dcir.Condition) + assert nested_SDFG.states[1].true_states + assert isinstance(nested_SDFG.states[1].true_states[0], dcir.ComputationState) + assert not nested_SDFG.states[1].false_states + assert isinstance(nested_SDFG.states[2], dcir.ComputationState) + + +def test_dcir_code_structure_while() -> None: + """Tests the following code structure + + ComputationState + WhileLoop + body: [ComputationState] + ComputationState + """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + WhileFactory(), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + assert len(expansions) == 1, "expect one vertical loop to be expanded" + + nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) + assert isinstance(nested_SDFG.states[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[1], dcir.WhileLoop) + assert nested_SDFG.states[1].body + assert isinstance(nested_SDFG.states[1].body[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[2], dcir.ComputationState) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py new file mode 100644 index 0000000000..561e994b27 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py @@ -0,0 +1,144 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dace +import pytest + +from gt4py.cartesian.gtc.common import BuiltInLiteral, DataType +from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder + +from cartesian_tests.unit_tests.test_gtc.dace import utils +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + BinaryOpFactory, + HorizontalExecutionFactory, + LiteralFactory, + LocalScalarFactory, + MaskStmtFactory, + ScalarAccessFactory, + StencilFactory, +) + + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_scalar_access_multiple_tasklets() -> None: + """Test scalar access if an oir.CodeBlock is split over multiple Tasklets. + + We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that + we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). + We thus create output connectors for all writes to scalar variables inside Tasklets. And input + connectors for all scalar reads unless previously written in the same Tasklet. DaCe's simplify + pipeline will get rid of any dead dataflow introduced with this general approach. + """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + MaskStmtFactory( + mask=LiteralFactory(value=BuiltInLiteral.TRUE, dtype=DataType.BOOL), body=[] + ), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + nsdfg = StencilComputationSDFGBuilder().visit(expansions[0]) + assert isinstance(nsdfg.sdfg, dace.SDFG) + + for node in nsdfg.sdfg.nodes()[1].nodes(): + if not isinstance(node, dace.nodes.NestedSDFG): + continue + + nested = node.sdfg + for state in nested.states(): + if state.name == "block_0": + nodes = state.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 + ) + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "tmp", + nodes, + ) + ) + ) + == 1 + ), "one AccessNode of tmp" + + edges = state.edges() + tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] + write_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", + nodes, + ) + )[0] + assert len(edges) == 1, "one edge expected" + assert ( + edges[0].src == tasklet and edges[0].dst == write_access + ), "write access of 'tmp'" + + if state.name == "block_1": + nodes = state.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 + ) + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "tmp", + nodes, + ) + ) + ) + == 1 + ), "one AccessNode of tmp" + + edges = state.edges() + tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] + read_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", + nodes, + ) + )[0] + write_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "other", + nodes, + ) + )[0] + assert len(edges) == 2, "two edges expected" + assert ( + edges[0].src == tasklet and edges[0].dst == write_access + ), "write access of 'other'" + assert ( + edges[1].src == read_access and edges[1].dst == tasklet + ), "read access of 'tmp'" diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py new file mode 100644 index 0000000000..ab501d722e --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from typing import Optional + +from gt4py.cartesian.gtc.common import DataType, CartesianOffset +from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import prefix +from gt4py.cartesian.gtc.dace import utils + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +@pytest.mark.parametrize( + "name,is_target,offset,expected", + [ + ("A", False, None, f"{prefix.TASKLET_IN}A"), + ("A", True, None, f"{prefix.TASKLET_OUT}A"), + ("A", True, CartesianOffset(i=0, j=0, k=-1), f"{prefix.TASKLET_OUT}Akm1"), + ("A", False, CartesianOffset(i=1, j=-2, k=3), f"{prefix.TASKLET_IN}Aip1_jm2_kp3"), + ( + "A", + True, + dcir.VariableKOffset(k=dcir.Literal(value="3", dtype=DataType.INT32)), + f"{prefix.TASKLET_OUT}A", + ), + ], +) +def test_get_tasklet_symbol( + name: str, + is_target: bool, + offset: Optional[CartesianOffset | dcir.VariableKOffset], + expected: str, +) -> None: + assert utils.get_tasklet_symbol(name, is_target=is_target, offset=offset) == expected diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py new file mode 100644 index 0000000000..b976631017 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py @@ -0,0 +1,54 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dace + +from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder +from gt4py.cartesian.gtc.dace.nodes import StencilComputation +from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder +from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion + +from cartesian_tests.unit_tests.test_gtc.oir_utils import StencilFactory + + +def library_node_expansions(stencil: StencilFactory) -> list[dcir.NestedSDFG]: + """Return all expanded library nodes in a given stencil.""" + sdfg = OirSDFGBuilder().visit(stencil) + assert isinstance(sdfg, dace.SDFG) + + expansions = [] + for state in sdfg.nodes(): + for node in state.nodes(): + if not isinstance(node, StencilComputation): + continue + + arrays = StencilComputationExpansion._get_parent_arrays(node, state, sdfg) + nested_SDFG = DaCeIRBuilder().visit( + node.oir_node, + global_ctx=DaCeIRBuilder.GlobalContext(library_node=node, arrays=arrays), + ) + expansions.append(nested_SDFG) + + return expansions + + +def nested_SDFG_inside_triple_loop(nSDFG: dcir.NestedSDFG) -> dcir.NestedSDFG: + """Pick the inner nested SDFG out of the triple loop.""" + assert isinstance(nSDFG, dcir.NestedSDFG) + assert isinstance(nSDFG.states[0], dcir.ComputationState) + assert isinstance(nSDFG.states[0].computations[0], dcir.DomainMap) + assert isinstance(nSDFG.states[0].computations[0].computations[0], dcir.DomainMap) + assert isinstance( + nSDFG.states[0].computations[0].computations[0].computations[0], dcir.DomainMap + ) + assert isinstance( + nSDFG.states[0].computations[0].computations[0].computations[0].computations[0], + dcir.NestedSDFG, + ) + return nSDFG.states[0].computations[0].computations[0].computations[0].computations[0] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py new file mode 100644 index 0000000000..9b8c127156 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py @@ -0,0 +1,159 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import dace +else: + dace = pytest.importorskip("dace") + +from gt4py.cartesian.gtc import oir +from gt4py.cartesian.gtc.common import DataType +from gt4py.cartesian.gtc.dace.nodes import StencilComputation +from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder + +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + FieldAccessFactory, + FieldDeclFactory, + ScalarAccessFactory, + StencilFactory, +) + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_oir_sdfg_builder_copy_stencil() -> None: + stencil_name = "copy" + stencil = StencilFactory( + name=stencil_name, + params=[ + FieldDeclFactory(name="A", dtype=DataType.FLOAT32), + FieldDeclFactory(name="B", dtype=DataType.FLOAT32), + ], + vertical_loops__0__sections__0__horizontal_executions__0__body=[ + AssignStmtFactory(left=FieldAccessFactory(name="B"), right=FieldAccessFactory(name="A")) + ], + ) + sdfg = OirSDFGBuilder().visit(stencil) + + assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" + assert sdfg.name == stencil_name, "Stencil name is preserved" + assert len(sdfg.arrays) == 2, "two arrays expected (A and B)" + + a_array = sdfg.arrays.get("A") + assert a_array is not None, "Array A expected to be defined" + assert a_array.ctype == "float", "A is of type `float`" + assert a_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" + + b_array = sdfg.arrays.get("B") + assert b_array is not None, "Array B expected to be defined" + assert b_array.ctype == "float", "B is of type `float`" + assert b_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" + + states = sdfg.nodes() + assert len(states) >= 1, "at least one state expected" + + # expect StencilComputation, AccessNode(A), and AccessNode(B) in the last block + last_block = states[len(states) - 1] + nodes = last_block.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1 + ), "one StencilComputation library node" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes + ) + ) + ) + == 1 + ), "one AccessNode of A" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes + ) + ) + ) + == 1 + ), "one AccessNode of B" + + edges = last_block.edges() + assert len(edges) == 2, "read and write memlet path expected" + + library_node = list(filter(lambda node: isinstance(node, StencilComputation), nodes))[0] + read_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) + )[0] + write_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes) + )[0] + + assert edges[0].src == read_access and edges[0].dst == library_node, "read access expected" + assert edges[1].src == library_node and edges[1].dst == write_access, "write access expected" + + +def test_oir_sdfg_builder_assign_scalar_param() -> None: + stencil_name = "scalar_assign" + stencil = StencilFactory( + name=stencil_name, + params=[ + FieldDeclFactory(name="A", dtype=DataType.FLOAT64), + oir.ScalarDecl(name="b", dtype=DataType.INT32), + ], + vertical_loops__0__sections__0__horizontal_executions__0__body=[ + AssignStmtFactory( + left=FieldAccessFactory(name="A"), right=ScalarAccessFactory(name="b") + ) + ], + ) + sdfg = OirSDFGBuilder().visit(stencil) + + assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" + assert sdfg.name == stencil_name, "Stencil name is preserved" + assert len(sdfg.arrays) == 1, "one array expected (A)" + + a_array = sdfg.arrays.get("A") + assert a_array is not None, "Array A expected to be defined" + assert a_array.ctype == "double", "Array A is of type `double`" + assert a_array.offset == (0, 0, 0), "CartesianOffset.zeros() expected" + assert "b" in sdfg.symbols.keys(), "expected `b` as scalar parameter" + + states = sdfg.nodes() + assert len(states) >= 1, "at least one state expected" + + last_block = states[len(states) - 1] + nodes = last_block.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1 + ), "one StencilComputation library node" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes + ) + ) + ) + == 1 + ), "one AccessNode of A" + + edges = last_block.edges() + library_node = list(filter(lambda node: isinstance(node, StencilComputation), nodes))[0] + write_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) + )[0] + assert len(edges) == 1, "write memlet path expected" + assert edges[0].src == library_node and edges[0].dst == write_access, "write access expected" From 8b96ee404ec71f93abb315f2d93cb7f98e0bf437 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 19 Mar 2025 11:42:29 +0100 Subject: [PATCH 178/178] refactor[cartesian, storage]: break dependency cycle; replace `GT4PY_USE_HIP` (#1916) ## Description This PR breaks the dependency cycle between `gt4py.cartesian` and `gt4py.storage`. The last puzzle piece was the distinction between AMD and NVIDIA GPUs. This was controlled by `GT4PY_USE_HIP`. With this PR we re-use `CUPY_DEVICE_TYPE` (for `_core/definitions`) for this purpose. `GT4PY_USE_HIP` could be set by an environment variable or be auto-detected. Auto-detection for `GT4PY_USE_HIP` and `CUPY_DEVICE_TYPE` is the same. And according to @stubbiali, the environment variable was never[^1] needed because auto-detection worked well. We thus don't expect issues removing the environment variable. Related issue: https://github.com/GridTools/gt4py/issues/1880 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Assumed to covered by existing tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A [^1]: https://github.com/GridTools/gt4py/pull/1867#issuecomment-2662742715 Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/backend/dace_backend.py | 3 ++- src/gt4py/cartesian/backend/pyext_builder.py | 14 +++++++------- src/gt4py/cartesian/config.py | 16 +++------------- src/gt4py/storage/cartesian/utils.py | 6 +++--- tach.toml | 2 +- 5 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5fef6d88ba..81775ade1e 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -19,6 +19,7 @@ from dace.sdfg.utils import inline_sdfgs from gt4py import storage as gt_storage +from gt4py._core import definitions as core_defs from gt4py.cartesian import config as gt_config from gt4py.cartesian.backend.base import CLIBackendMixin, register from gt4py.cartesian.backend.gtc_common import ( @@ -523,7 +524,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDF with dace.config.temporary_config(): # To prevent conflict with 3rd party usage of DaCe config always make sure that any # changes be under the temporary_config manager - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: dace.config.Config.set("compiler", "cuda", "backend", value="hip") dace.config.Config.set("compiler", "cuda", "max_concurrent_streams", value=-1) dace.config.Config.set( diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index 8f49ce6f22..8875e3e3af 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -18,6 +18,7 @@ from setuptools import distutils from setuptools.command.build_ext import build_ext +from gt4py._core import definitions as core_defs from gt4py.cartesian import config as gt_config @@ -51,6 +52,7 @@ def get_gt_pyext_build_opts( ) -> Dict[str, Union[str, List[str], Dict[str, Any]]]: include_dirs = [gt_config.build_settings["boost_include_path"]] extra_compile_args_from_config = gt_config.build_settings["extra_compile_args"] + is_rocm_gpu = core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM if uses_cuda: compute_capability = get_cuda_compute_capability() @@ -68,8 +70,6 @@ def get_gt_pyext_build_opts( gt_include_path = gt_config.build_settings["gt_include_path"] - import os - extra_compile_args = dict( cxx=[ "-std=c++17", @@ -93,7 +93,7 @@ def get_gt_pyext_build_opts( "-DBOOST_OPTIONAL_USE_OLD_DEFINITION_OF_NONE", *extra_compile_args_from_config["cuda"], ] - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: extra_compile_args["cuda"] += [ "-isystem{}".format(gt_include_path), "-isystem{}".format(gt_config.build_settings["boost_include_path"]), @@ -125,7 +125,7 @@ def get_gt_pyext_build_opts( extra_compile_args["cxx"].append( "-isystem{}".format(os.path.join(dace_path, "runtime/include")) ) - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: extra_compile_args["cuda"].append( "-isystem{}".format(os.path.join(dace_path, "runtime/include")) ) @@ -158,7 +158,7 @@ def get_gt_pyext_build_opts( if uses_cuda: cuda_flags = [] for cpp_flag in cpp_flags: - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: cuda_flags.extend([cpp_flag]) else: cuda_flags.extend(["--compiler-options", cpp_flag]) @@ -309,7 +309,7 @@ def build_pybind_cuda_ext( library_dirs = library_dirs or [] library_dirs = [*library_dirs, gt_config.build_settings["cuda_library_path"]] libraries = libraries or [] - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: libraries = [*libraries, "hiprtc"] else: libraries = [*libraries, "cudart"] @@ -363,7 +363,7 @@ def cuda_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): cflags = copy.deepcopy(extra_postargs) try: if os.path.splitext(src)[-1] == ".cu": - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: cuda_exec = os.path.join(gt_config.build_settings["cuda_bin_path"], "hipcc") else: cuda_exec = os.path.join(gt_config.build_settings["cuda_bin_path"], "nvcc") diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index 5aa32506b7..a48f612c84 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -12,6 +12,8 @@ import gridtools_cpp +from gt4py._core import definitions as core_defs + GT4PY_INSTALLATION_PATH: str = os.path.dirname(os.path.abspath(__file__)) @@ -26,18 +28,6 @@ CUDA_HOST_CXX: Optional[str] = os.environ.get("CUDA_HOST_CXX", None) -if "GT4PY_USE_HIP" in os.environ: - GT4PY_USE_HIP: bool = bool(int(os.environ["GT4PY_USE_HIP"])) -else: - # Autodetect cupy with ROCm/HIP support - try: - import cupy as _cp - - GT4PY_USE_HIP = _cp.cuda.get_hipcc_path() is not None - del _cp - except Exception: - GT4PY_USE_HIP = False - GT_INCLUDE_PATH: str = os.path.abspath(gridtools_cpp.get_include_dir()) GT_CPP_TEMPLATE_DEPTH: int = 1024 @@ -66,7 +56,7 @@ "parallel_jobs": multiprocessing.cpu_count(), "cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH), } -if GT4PY_USE_HIP: +if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib") else: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib64") diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index d2c5ff066f..2275c1cd57 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -19,7 +19,6 @@ from numpy.typing import DTypeLike from gt4py._core import definitions as core_defs -from gt4py.cartesian import config as gt_config from gt4py.eve.extended_typing import ArrayInterface, CUDAArrayInterface from gt4py.storage import allocators @@ -259,9 +258,10 @@ def _allocate_gpu( ) -> Tuple["cp.ndarray", "cp.ndarray"]: assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" + if core_defs.CUPY_DEVICE_TYPE is None: + raise ValueError("CUPY_DEVICE_TYPE detection failed.") device = core_defs.Device( # type: ignore[type-var] - (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), - 0, + core_defs.CUPY_DEVICE_TYPE, 0 ) buffer = _GPUBufferAllocator.allocate( shape, diff --git a/tach.toml b/tach.toml index d23b5fb14d..78541c5dff 100644 --- a/tach.toml +++ b/tach.toml @@ -16,6 +16,7 @@ depends_on = [ [[modules]] path = "gt4py.cartesian" depends_on = [ + { path = "gt4py._core" }, { path = "gt4py.eve" }, { path = "gt4py.storage" }, ] @@ -36,6 +37,5 @@ depends_on = [ path = "gt4py.storage" depends_on = [ { path = "gt4py._core" }, - { path = "gt4py.cartesian" }, # for backward-compatibility the cartesian allocators are in `gt4py.storage` { path = "gt4py.eve" }, ]