From 768e4585e09563155edece4e1839dacc80331641 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 12:51:06 +0100 Subject: [PATCH 01/44] First draft --- src/gt4py/next/ffront/past_to_itir.py | 4 +++- src/gt4py/next/iterator/builtins.py | 6 ++++++ src/gt4py/next/iterator/embedded.py | 5 +++++ src/gt4py/next/iterator/runtime.py | 5 +++++ .../next/iterator/type_system/type_synthesizer.py | 7 +++++++ .../next/program_processors/codegens/gtfn/codegen.py | 12 ++++++++++++ .../next/program_processors/codegens/gtfn/gtfn_ir.py | 1 + .../codegens/gtfn/itir_to_gtfn_ir.py | 8 ++++++++ src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 9 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4bc1dfb2f8..b480203050 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -349,7 +349,9 @@ def _construct_itir_domain_arg( domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension - dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i)) + dim_range = im.call("get_domain")( + out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) + ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds lower: itir.Expr diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 8e5f7addca..f197107542 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -402,6 +402,11 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def get_domain(*args): + raise BackendNotSelectedError() + + UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -474,6 +479,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "cartesian_domain", "cast_", "deref", + "get_domain", 
"if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index da0516d26b..6e5bb4608c 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1683,6 +1683,11 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable operators._tuple_assign_field(target, expr, common.domain(domain)) +@runtime.get_domain.register(EMBEDDED) +def get_domain(field: common.Field, dim: common.Dimension) -> tuple[int, int]: + return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) + + @runtime.if_stmt.register(EMBEDDED) def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: """ diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index c9a5b15de7..c831f33f26 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -213,6 +213,11 @@ def set_at(*args): return BackendNotSelectedError() +@builtin_dispatch +def get_domain(*args): + return BackendNotSelectedError() + + @builtin_dispatch def if_stmt(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 131b773dd2..d7baa72f12 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -329,6 +329,13 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: return applied_as_fieldop +@_register_builtin_type_synthesizer +def get_domain(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: + return ts.TupleType( + types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] * 2 + ) + + @_register_builtin_type_synthesizer def scan( scan_pass: TypeSynthesizer, direction: ts.ScalarType, 
init: ts.ScalarType diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 969e203689..a9f596effc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -110,6 +110,8 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: return self.asfloat(node.value) case "bool": return node.value.lower() + case "axis_literal": + return node.value + "_t" case _: # TODO(tehrengruber): we should probably shouldn't just allow anything here. Revisit. return node.value @@ -272,6 +274,16 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include + namespace gridtools::fn { + // TODO(tehrengruber): `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type` + // fails as type used for index calculations in gtfn differs + template + GT_FUNCTION gridtools::tuple get_domain(S &&sid, D) { + return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), + gridtools::host_device::at_key(gridtools::sid::get_upper_bounds(sid))}; + } + } + namespace generated{ namespace gtfn = ::gridtools::fn; diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 831694791a..e45ce983e5 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -226,6 +226,7 @@ class TemporaryAllocation(Node): "can_deref", "cartesian_domain", "unstructured_domain", + "get_domain", "named_range", "reduce", "index", diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 104e2eccc1..2bc25de203 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ 
b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -485,6 +485,14 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: tagged_sizes=sizes, tagged_offsets=domain_offsets, connectivities=connectivities ) + def _visit_get_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: + field, dim = node.args + + return FunCall( + fun=SymRef(id="get_domain"), + args=[self.visit(field, **kwargs), self.visit(dim, **kwargs)], + ) + def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node: if isinstance(node.fun, itir.SymRef): if node.fun.id in self._unary_op_map: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a8961fd9bc..2fa273322e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler -from gt4py.next.otf.compilation.build_systems import compiledb +from gt4py.next.otf.compilation.build_systems import cmake from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -141,7 +141,7 @@ class Params: lambda: config.CMAKE_BUILD_TYPE ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( - lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) + lambda o: cmake.CMakeFactory(cmake_build_type=o.cmake_build_type) ) cached_translation = factory.Trait( From ac7db53d988ea9635b70cdf9eb6341bbce10681e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 12:54:03 +0100 Subject: [PATCH 02/44] Remove debugging leftovers --- src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 
2fa273322e..a8961fd9bc 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler -from gt4py.next.otf.compilation.build_systems import cmake +from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -141,7 +141,7 @@ class Params: lambda: config.CMAKE_BUILD_TYPE ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( - lambda o: cmake.CMakeFactory(cmake_build_type=o.cmake_build_type) + lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) cached_translation = factory.Trait( From 373181042f631f36e470ec1e41c3b9c0953b9eaa Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 30 Jul 2025 09:23:00 +0200 Subject: [PATCH 03/44] Get domain from tuple element --- src/gt4py/next/ffront/past_to_itir.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 04265985bc..1cd8f65b90 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,11 +9,13 @@ from __future__ import annotations import dataclasses +import functools from typing import Any, Optional, cast import devtools from gt4py.eve import NodeTranslator, concepts, traits +from gt4py.eve import utils as eve_utils from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -362,13 +364,17 @@ def _construct_itir_domain_arg( " fields defined on the same dimensions. This error should be " " caught in type deduction already." 
) + # if the out_field is a (potentially nested) tuple we get the domain from its first + # element + first_out_el_path = eve_utils.first(type_info.primitive_constituents(out_field.type, with_path_arg=True))[1] + first_out_el = functools.reduce(lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id) domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension dim_range = im.call("get_domain")( - out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) + first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds From 0e4eb57ae840829b75e0f36e4ce8b0c48dd82435 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 14:27:04 +0200 Subject: [PATCH 04/44] Rename get_domain to get_domain_range --- src/gt4py/next/ffront/past_to_itir.py | 13 ++++++++----- src/gt4py/next/iterator/builtins.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 4 ++-- src/gt4py/next/iterator/runtime.py | 2 +- .../next/iterator/type_system/type_synthesizer.py | 5 +++-- .../program_processors/codegens/gtfn/codegen.py | 3 ++- .../program_processors/codegens/gtfn/gtfn_ir.py | 2 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 4 ++-- 8 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 79c6a4b36b..f6f9c088fa 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,8 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.eve import utils as eve_utils +from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -366,14 +365,18 @@ def _construct_itir_domain_arg( ) # if the out_field is a (potentially nested) tuple we get the domain from its first # 
element - first_out_el_path = eve_utils.first(type_info.primitive_constituents(out_field.type, with_path_arg=True))[1] - first_out_el = functools.reduce(lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id) + first_out_el_path = eve_utils.first( + type_info.primitive_constituents(out_field.type, with_path_arg=True) + )[1] + first_out_el = functools.reduce( + lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id + ) domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension - dim_range = im.call("get_domain")( + dim_range = im.call("get_domain_range")( first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 153105413b..e54c6ea3d7 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -413,7 +413,7 @@ def concat_where(*args): @builtin_dispatch -def get_domain(*args): +def get_domain_range(*args): raise BackendNotSelectedError() @@ -490,7 +490,7 @@ def get_domain(*args): "cartesian_domain", "cast_", "deref", - "get_domain", + "get_domain_range", "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 739f7f2dfa..9e5c9a0efb 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1678,8 +1678,8 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable operators._tuple_assign_field(target, expr, common.domain(domain)) -@runtime.get_domain.register(EMBEDDED) -def get_domain(field: common.Field, dim: common.Dimension) -> tuple[int, int]: +@runtime.get_domain_range.register(EMBEDDED) +def get_domain_range(field: common.Field, dim: common.Dimension) -> tuple[int, 
int]: return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index c927cda843..a995385b0f 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -216,7 +216,7 @@ def set_at(*args): @builtin_dispatch -def get_domain(*args): +def get_domain_range(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 258020ea7e..cc0759b79b 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -589,9 +589,10 @@ def applied_as_fieldop( @_register_builtin_type_synthesizer -def get_domain(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: +def get_domain_range(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: return ts.TupleType( - types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] * 2 + types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] + * 2 ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index bf84bb8519..f142ff006e 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -273,11 +273,12 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include + // TODO(tehrengruber): This should disappear as soon as we introduce a proper builtin. 
namespace gridtools::fn { // TODO(tehrengruber): `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type` // fails as type used for index calculations in gtfn differs template - GT_FUNCTION gridtools::tuple get_domain(S &&sid, D) { + GT_FUNCTION gridtools::tuple get_domain_range(S &&sid, D) { return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), gridtools::host_device::at_key(gridtools::sid::get_upper_bounds(sid))}; } diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1e5023f5be..aa5a94991c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -231,7 +231,7 @@ class TemporaryAllocation(Node): "can_deref", "cartesian_domain", "unstructured_domain", - "get_domain", + "get_domain_range", "named_range", "reduce", "index", diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dac7823451..02fb4fdbe2 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -471,11 +471,11 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: tagged_sizes=sizes, tagged_offsets=domain_offsets, connectivities=connectivities ) - def _visit_get_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: + def _visit_get_domain_range(self, node: itir.FunCall, **kwargs: Any) -> Node: field, dim = node.args return FunCall( - fun=SymRef(id="get_domain"), + fun=SymRef(id="get_domain_range"), args=[self.visit(field, **kwargs), self.visit(dim, **kwargs)], ) From d916337482d1ac52152b3ebcfda70e3f4bcd00a2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 15:15:55 +0200 Subject: [PATCH 05/44] Remove compile time args --- 
.../next/advanced/ToolchainWalkthrough.md | 10 ++-- src/gt4py/next/backend.py | 2 +- src/gt4py/next/ffront/past_process_args.py | 44 +++----------- src/gt4py/next/ffront/past_to_itir.py | 27 --------- src/gt4py/next/iterator/embedded.py | 5 -- src/gt4py/next/otf/arguments.py | 60 +------------------ .../runners/dace/workflow/decoration.py | 6 +- .../next/program_processors/runners/gtfn.py | 1 - .../gtfn_tests/test_gtfn_module.py | 6 +- 9 files changed, 20 insertions(+), 141 deletions(-) diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index 4d71c0ffe9..8cb8293b7a 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -134,13 +134,13 @@ So far we have gotten away with empty compile time arguments, now we need to sup ```python jit_args = gtx.otf.arguments.JITArgs.from_signature( - gtx.ones(domain={I: 10}, dtype=gtx.float64), - out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), - offset_provider=OFFSET_PROVIDER, + gtx.ones(domain={I: 10}, dtype=gtx.float64), + out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), + offset_provider=OFFSET_PROVIDER, ) -aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete_no_size( - *jit_args.args, **jit_args.kwargs +aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete( + *jit_args.args, **jit_args.kwargs ) ``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e8fa6b2ac5..a20c72c24b 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -159,7 +159,7 @@ def __call__( def jit(self, program: INPUT_DATA, *args: Any, **kwargs: Any) -> stages.CompiledProgram: if not isinstance(program, IT_PRG): args, kwargs = signature.convert_to_positional(program, *args, **kwargs) - aot_args = arguments.CompileTimeArgs.from_concrete_no_size(*args, **kwargs) + aot_args = arguments.CompileTimeArgs.from_concrete(*args, **kwargs) return self.compile(program, aot_args) def compile(self, 
program: INPUT_DATA, compile_time_args: CARG) -> stages.CompiledProgram: diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index f0360e05ba..95db2837dd 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -24,14 +24,14 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: - rewritten_args, size_args, kwargs = _process_args( + rewritten_args, rewritten_kwargs = _process_args( past_node=inp.data.past_node, args=inp.args.args, kwargs=inp.args.kwargs ) return toolchain.CompilableProgram( data=inp.data, args=arguments.CompileTimeArgs( - args=tuple((*rewritten_args, *(size_args))), - kwargs=kwargs, + args=rewritten_args, + kwargs=rewritten_kwargs, offset_provider=inp.args.offset_provider, column_axis=inp.args.column_axis, ), @@ -65,50 +65,20 @@ def _process_args( past_node: past.Program, args: Sequence[ts.TypeSpec | arguments.StaticArg], kwargs: dict[str, ts.TypeSpec | arguments.StaticArg], -) -> tuple[tuple, tuple, dict[str, Any]]: +) -> tuple[tuple, dict[str, Any]]: if not isinstance(past_node.type, ts_ffront.ProgramType): raise TypeError("Can not process arguments for PAST programs prior to type inference.") args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) + + # validate arguments arg_types = tuple(arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in args) kwarg_types = { k: (v.type_ if isinstance(v, arguments.StaticArg) else v) for k, v in kwargs.items() } _validate_args(past_node=past_node, arg_types=arg_types, kwarg_types=kwarg_types) - implicit_domain = any( - isinstance(stmt, past.Call) and "domain" not in stmt.kwargs for stmt in past_node.body - ) - - # extract size of all field arguments - size_args: list[ts.TypeSpec] = [] - rewritten_args = list(args) - for param_idx, param in enumerate(past_node.params): - if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): - # TODO(tehrengruber): 
Previously this function was called with the actual arguments - # not their type. The check using the shape here is not functional anymore and - # should instead be placed in a proper location. - ranges_and_dims = [ - *_field_constituents_range_and_dims(arg_types[param_idx], param.type) - ] - # check that all non-scalar like constituents have the same shape and dimension, e.g. - # for `(scalar, (field1, field2))` the two fields need to have the same shape and - # dimension - if ranges_and_dims: - range_, dims = ranges_and_dims[0] - if not all( - el_range == range_ and el_dims == dims - for (el_range, el_dims) in ranges_and_dims - ): - raise ValueError( - "Constituents of composite arguments (e.g. the elements of a" - " tuple) need to have the same shape and dimensions." - ) - index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - size_args.extend( - range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty - ) - return tuple(rewritten_args), tuple(size_args), kwargs + return args, kwargs def _field_constituents_range_and_dims( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f6f9c088fa..f0c8060e65 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -223,32 +223,6 @@ def apply( ) -> itir.Program: return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) - def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: - """Generate symbols for each field param and dimension.""" - size_params = [] - for param in node.params: - fields_dims: list[list[common.Dimension]] = ( - type_info.primitive_constituents(param.type) - .if_isinstance(ts.FieldType) - .getattr("dims") - .filter(lambda dims: len(dims) > 0) - .to_list() - ) - if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` - assert all(field_dims == fields_dims[0] for 
field_dims in fields_dims) - index_type = ts.ScalarType( - kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) - ) - for dim_idx in range(len(fields_dims[0])): - size_params.append( - itir.Sym( - id=_range_arg_from_field(param.id, dim_idx), - type=ts.TupleType(types=[index_type, index_type]), - ) - ) - - return size_params - def visit_Program( self, node: past.Program, @@ -265,7 +239,6 @@ def visit_Program( implicit_domain = False if any("domain" not in body_entry.kwargs for body_entry in node.body): - params = params + self._gen_size_params_from_program(node) implicit_domain = True set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 9e5c9a0efb..b04ba8c42d 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1862,11 +1862,6 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): common.UnitRange(0, 0), # empty: indicates column operation, will update later ) - import inspect - - if len(args) < len(inspect.getfullargspec(fun).args): - args = (*args, *arguments.iter_size_args(args)) - with embedded_context.update(**context_vars): fun(*args) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 39540baebb..980c9849e2 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -55,7 +55,7 @@ def offset_provider_type(self) -> common.OffsetProviderType: return common.offset_provider_to_type(self.offset_provider) @classmethod - def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: + def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" compile_args = tuple(type_translation.from_value(arg) for arg in args) kwargs_copy = kwargs.copy() @@ -69,17 +69,6 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: 
}, ) - @classmethod - def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: - """Convert concrete GTX program arguments to compile-time, adding (compile-time) dimension size arguments.""" - no_size = cls.from_concrete_no_size(*args, **kwargs) - return cls( - args=(*no_size.args, *iter_size_compile_args(no_size.args)), - offset_provider=no_size.offset_provider, - column_axis=no_size.column_axis, - kwargs=no_size.kwargs, - ) - @classmethod def empty(cls) -> Self: return cls(tuple(), {}, {}, None) @@ -88,7 +77,7 @@ def empty(cls) -> Self: def jit_to_aot_args( inp: JITArgs, ) -> CompileTimeArgs: - return CompileTimeArgs.from_concrete_no_size(*inp.args, **inp.kwargs) + return CompileTimeArgs.from_concrete(*inp.args, **inp.kwargs) def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ @@ -110,47 +99,4 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return element case _: pass - return None - - -def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]: - """ - Yield the size of each field argument in each dimension. - - This can be used to generate domain size arguments for FieldView Programs that use an implicit domain. - """ - for arg in args: - match arg: - case tuple(): - # we only need the first field, because all fields in a tuple must have the same dims and sizes - first_field = find_first_field(arg) - if first_field: - yield from iter_size_args((first_field,)) - case common.Field(): - for range_ in arg.domain.ranges: - assert isinstance(range_, common.UnitRange) - yield (range_.start, range_.stop) - case _: - pass - - -def iter_size_compile_args( - args: Iterable[ts.TypeSpec | StaticArg], -) -> Iterator[ts.TypeSpec]: - """ - Yield a compile-time size argument for every compile-time field argument in each dimension. - - This can be used inside transformation workflows to generate compile-time domain size arguments for FieldView Programs that use an implicit domain. 
- """ - for arg in args: - type_ = arg.type_ if isinstance(arg, StaticArg) else arg - field_constituents: list[ts.FieldType] = typing.cast( - list[ts.FieldType], - type_info.primitive_constituents(type_).if_isinstance(ts.FieldType).to_list(), - ) - if field_constituents: - # we only need the first field, because all fields in a tuple must have the same dims and sizes - index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - yield from [ - ts.TupleType(types=[index_type, index_type]) for _ in field_constituents[0].dims - ] + return None \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index b551381354..054e148e01 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -38,11 +38,7 @@ def decorated_program( if out is not None: args = (*args, out) - if fun.implicit_domain: - # Generate implicit domain size arguments only if necessary - size_args = arguments.iter_size_args(args) - args = (*args, *size_args) - + # TODO: this doesn't belong here and should by done in the dace backend if not fun.sdfg_program._lastargs: # First call, the SDFG is not intitalized, so forward the call to `CompiledSDFG` # to proper initilize it. 
Later calls to this SDFG will be handled through diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e395bcf991..ae326184d8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -74,7 +74,6 @@ def decorated_program( # generate implicit domain size arguments only if necessary, using `iter_size_args()` inp( *converted_args, - *(arguments.iter_size_args(args) if inp.implicit_domain else ()), *conn_args, **opt_kwargs, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 510c03e314..acfbede4eb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -77,7 +77,7 @@ def test_codegen(program_example): module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) @@ -91,7 +91,7 @@ def test_hash_and_diskcache(program_example, tmp_path): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) @@ -135,7 +135,7 @@ def test_gtfn_file_cache(program_example): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) From 25e24e9afc0ff71868efdca13cf631a8289fa2be Mon Sep 17 00:00:00 2001 From: 
tehrengruber Date: Mon, 18 Aug 2025 15:16:18 +0200 Subject: [PATCH 06/44] Fix format --- docs/user/next/advanced/ToolchainWalkthrough.md | 10 ++++------ src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/embedded.py | 1 - src/gt4py/next/otf/arguments.py | 6 +++--- .../runners/dace/workflow/decoration.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 2 +- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 12 +++--------- 7 files changed, 13 insertions(+), 22 deletions(-) diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index 8cb8293b7a..d730eed37e 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -134,14 +134,12 @@ So far we have gotten away with empty compile time arguments, now we need to sup ```python jit_args = gtx.otf.arguments.JITArgs.from_signature( - gtx.ones(domain={I: 10}, dtype=gtx.float64), - out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), - offset_provider=OFFSET_PROVIDER, + gtx.ones(domain={I: 10}, dtype=gtx.float64), + out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), + offset_provider=OFFSET_PROVIDER, ) -aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete( - *jit_args.args, **jit_args.kwargs -) +aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete(*jit_args.args, **jit_args.kwargs) ``` ```python diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f0c8060e65..e7b9412ebf 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -25,7 +25,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.stages import AOT_PRG -from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import remap_symbols from gt4py.next.otf import arguments, stages, workflow diff --git 
a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b04ba8c42d..c43a3422b5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,7 +54,6 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime -from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 980c9849e2..04f4bcf1e8 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -10,14 +10,14 @@ import dataclasses import typing -from typing import Any, Generic, Iterable, Iterator, Optional +from typing import Any, Generic, Optional from typing_extensions import Self from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation DATA_T = typing.TypeVar("DATA_T") @@ -99,4 +99,4 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return element case _: pass - return None \ No newline at end of file + return None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 054e148e01..240de755fe 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, config, metrics, utils as gtx_utils -from gt4py.next.otf import arguments, stages +from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable, workflow as dace_worflow from . 
import common as dace_common diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index ae326184d8..29cacd06e1 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py._core import locking from gt4py.next import backend, common, config, field_utils, metrics from gt4py.next.embedded import nd_array_field -from gt4py.next.otf import arguments, recipes, stages, workflow +from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index acfbede4eb..6eba78bb23 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -77,9 +77,7 @@ def test_codegen(program_example): module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) ) assert module.entry_point.name == fencil.id @@ -91,9 +89,7 @@ def test_hash_and_diskcache(program_example, tmp_path): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) hash = stages.fingerprint_compilable_program(compilable_program) @@ -135,9 +131,7 @@ def 
test_gtfn_file_cache(program_example): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True From e856a1899b1254a6f5d1168e071909cb91727acf Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 15:46:13 +0200 Subject: [PATCH 07/44] Fix failing tests --- tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index cbaa84454d..ae70d27ea8 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -99,8 +99,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), - P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) @@ -191,8 +189,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), - P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) From 00a11a62695b0a9a1bd9510cb9d8ec5622e67a11 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 16:30:59 +0200 Subject: [PATCH 08/44] Fix failing tests --- .../ffront_tests/test_past_to_gtir.py | 73 +++++++------------ 1 file changed, 25 insertions(+), 48 deletions(-) diff --git 
a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index ae70d27ea8..aede44283e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -8,6 +8,7 @@ import re +from typing import Literal import pytest @@ -40,6 +41,25 @@ def gtir_identity_fundef(): ) +def get_domain_range_pattern(field: str, dim: str, idx: Literal[0, 1]): + return P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value=str(idx), + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("get_domain_range")), + args=[P(itir.SymRef, id=eve.SymbolRef(field)), P(itir.AxisLiteral, value=dim)], + ), + ], + ) + + def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node = ProgramParser.apply_to_function(copy_program_def) itir_node = ProgramLowering.apply( @@ -58,30 +78,8 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="1", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), + get_domain_range_pattern("out", "IDim", 1), ], ) ], @@ -128,18 +126,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("plus")), args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, 
id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), P( itir.Literal, value="1", @@ -155,18 +142,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("plus")), args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), P( itir.Literal, value="2", @@ -183,6 +159,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) ], ), ) + program_pattern = P( itir.Program, id=eve.SymbolName("copy_restrict_program"), From f0bb72f15867bfa019d76fa8b4d79219b09327f7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 22 Aug 2025 13:30:51 +0200 Subject: [PATCH 09/44] Extend prototype for multiple output domains --- src/gt4py/next/embedded/operators.py | 32 +++++++--- .../next/ffront/past_passes/type_deduction.py | 64 +++++++++++-------- src/gt4py/next/ffront/past_to_itir.py | 35 ++++++---- src/gt4py/next/iterator/embedded.py | 8 ++- src/gt4py/next/iterator/ir_utils/misc.py | 10 ++- .../next/iterator/transforms/infer_domain.py | 13 +++- 6 files changed, 110 insertions(+), 52 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index ea393e2ad0..b309c6601e 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -108,8 +108,18 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> domain = kwargs.pop("domain", None) - out_domain = common.domain(domain) if domain is not None else _get_out_domain(out) + if domain is not None: + if isinstance(out, 
tuple) and not isinstance(domain, tuple): + out_domain = tuple([domain] * len(out)) + else: + out_domain = domain + else: + if isinstance(out, tuple): + out_domain = tuple(_get_out_domain(o) for o in out) + else: + out_domain = _get_out_domain(out) + # TODO? new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) with embedded_context.update(**new_context_kwargs): @@ -128,26 +138,32 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> return op(*args, **kwargs) -def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: - vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] - assert len(vertical_dim_filtered) <= 1 - return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING +def _get_vertical_range( + domain: common.Domain | tuple[common.Domain, ...], +) -> common.NamedRange | eve.NothingType | tuple[common.NamedRange | eve.NothingType, ...]: + if isinstance(domain, tuple): + return tuple(_get_vertical_range(sub) for sub in domain) + else: + vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING def _tuple_assign_field( target: tuple[common.MutableField | tuple, ...] | common.MutableField, source: tuple[common.Field | tuple, ...] 
| common.Field, - domain: common.Domain, + domain: common.DomainLike | tuple[common.DomainLike | tuple, ...], ) -> None: @utils.tree_map - def impl(target: common.MutableField, source: common.Field) -> None: + def impl(target: common.MutableField, source: common.Field, domain: common.DomainLike) -> None: + domain = common.domain(domain) if isinstance(source, common.Field): target[domain] = source[domain] else: assert core_defs.is_scalar_type(source) target[domain] = source - impl(target, source) + impl(target, source, domain) def _intersect_scan_args( diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9355273588..2980014e27 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -8,7 +8,7 @@ from typing import Any, Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next import errors +from gt4py.next import common, errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -57,28 +57,35 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] - if not isinstance(domain_kwarg, past.Dict): + if not isinstance(domain_kwarg, (past.Dict, past.TupleExpr)): raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") - - if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: - raise ValueError("Empty domain not allowed.") - - for dim in domain_kwarg.keys_: - if not isinstance(dim.type, ts.DimensionType): - raise ValueError( - f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." - ) - for domain_values in domain_kwarg.values_: - if len(domain_values.elts) != 2: - raise ValueError( - f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." 
- ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): - raise ValueError( - f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." - ) + if isinstance(domain_kwarg, past.Dict): + domain_kwarg_ = [domain_kwarg] + elif isinstance(domain_kwarg, past.TupleExpr): + out_kwarg = new_kwargs["out"] + assert isinstance(out_kwarg, past.TupleExpr) + assert len(out_kwarg.elts) == len(domain_kwarg.elts) + domain_kwarg_ = domain_kwarg.elts + for dom in domain_kwarg_: + if len(dom.values_) == 0 and len(dom.keys_) == 0: + raise ValueError("Empty domain not allowed.") + + for dim in dom.keys_: + if not isinstance(dim.type, ts.DimensionType): + raise ValueError( + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." + ) + for domain_values in dom.values_: + if len(domain_values.elts) != 2: + raise ValueError( + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." + ) + if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( + domain_values.elts[1] + ): + raise ValueError( + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." 
+ ) class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -133,9 +140,16 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) - return past.TupleExpr( - elts=elts, type=ts.TupleType(types=[el.type for el in elts]), location=node.location - ) + ttype: ts.TupleType + if any(isinstance(elt, past.Dict) for elt in node.elts): + assert all(isinstance(elt, past.Dict) for elt in node.elts) + # TODO: add check that Dict is DomainLike + ttype = ts.TupleType( + types=[ts.DomainType(dims=[common.Dimension(elt.keys_[0].id)]) for elt in elts] + ) + else: + ttype = ts.TupleType(types=[elt.type for elt in elts]) + return past.TupleExpr(elts=elts, type=ttype, location=node.location) def _deduce_binop_type( self, node: past.BinOp, *, left: past.Expr, right: past.Expr, **kwargs: Any diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index cd51fe567d..4bb76986cc 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -160,15 +160,17 @@ def _range_arg_from_field(field_name: str, dim: int) -> str: return f"__{field_name}_{dim}_range" -def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: - if isinstance(node, (past.Name, past.Subscript)): +def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript | past.Dict]: + if isinstance(node, (past.Name, past.Subscript, past.Dict)): return [node] elif isinstance(node, past.TupleExpr): result = [] for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") + raise ValueError( + f"Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed, got '{type(node)}'." 
+ ) @dataclasses.dataclass @@ -346,7 +348,7 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, - out_field: past.Name, + out_field: past.Name | past.Subscript | past.Dict, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: @@ -471,30 +473,37 @@ def _visit_stencil_call_out_arg( flattened = _flatten_tuple_expr(out_arg) first_field = flattened[0] - assert all( - self.visit(field.type).dims == self.visit(first_field.type).dims - for field in flattened - ), "Incompatible fields in tuple: all fields must have the same dimensions." field_slice = None if isinstance(first_field, past.Subscript): + raise AssertionError # TODO support slicing of multiple fields with different domain assert all(isinstance(field, past.Subscript) for field in flattened), ( "Incompatible field in tuple: either all fields or no field must be sliced." ) assert all( concepts.eq_nonlocated( first_field.slice_, - field.slice_, # type: ignore[union-attr] # mypy cannot deduce type + field.slice_, ) for field in flattened ), "Incompatible field in tuple: all fields must be sliced in the same way." field_slice = self._compute_field_slice(first_field) first_field = first_field.value - return ( - self._construct_itir_out_arg(out_arg), - self._construct_itir_domain_arg(first_field, domain_arg, field_slice), - ) + if isinstance(domain_arg, past.TupleExpr): + domain_args = [ + self._construct_itir_domain_arg(field, domain, None) + for field, domain in zip( + flattened, _flatten_tuple_expr(domain_arg), strict=True + ) + ] + domain_expr = im.make_tuple(*domain_args) + return self._construct_itir_out_arg(out_arg), domain_expr + else: + return ( + self._construct_itir_out_arg(out_arg), + self._construct_itir_domain_arg(first_field, domain_arg, field_slice), + ) else: raise AssertionError( "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." 
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3888ccf2de..6901888f21 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1674,8 +1674,12 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider @runtime.set_at.register(EMBEDDED) -def set_at(expr: common.Field, domain: common.DomainLike, target: common.MutableField) -> None: - operators._tuple_assign_field(target, expr, common.domain(domain)) +def set_at( + expr: common.Field, + domain: common.DomainLike | tuple[common.DomainLike | tuple, ...], + target: common.MutableField, +) -> None: + operators._tuple_assign_field(target, expr, domain) @runtime.if_stmt.register(EMBEDDED) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 00ff9abbd9..e91bf14b9a 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -227,8 +227,16 @@ def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: return common.GridType.UNSTRUCTURED +def _flatten_tuple_expr(domain_expr: itir.Expr) -> tuple[itir.Expr]: + if cpm.is_call_to(domain_expr, "make_tuple"): + return sum((_flatten_tuple_expr(arg) for arg in domain_expr.args), start=()) + else: + return (domain_expr,) + + def grid_type_from_program(program: itir.Program) -> common.GridType: - domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + domain_exprs = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + domains = sum((_flatten_tuple_expr(domain_expr) for domain_expr in domain_exprs), start=()) grid_types = {grid_type_from_domain(d) for d in domains} if len(grid_types) != 1: raise ValueError( diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index c22b775468..7fa5151e3a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ 
b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -519,6 +519,13 @@ def infer_expr( return expr, accessed_domains +def _make_symbolic_domain_tuple(domains: itir.Node) -> DomainAccess: + if cpm.is_call_to(domains, "make_tuple"): + return tuple(_make_symbolic_domain_tuple(arg) for arg in domains.args) + else: + return SymbolicDomain.from_expr(domains) + + def _infer_stmt( stmt: itir.Stmt, **kwargs: Unpack[InferenceOptions], @@ -528,9 +535,9 @@ def _infer_stmt( # between the domain stored in IR and in the annex domain = constant_folding.ConstantFolding.apply(stmt.domain) - transformed_call, _ = infer_expr( - stmt.expr, domain_utils.SymbolicDomain.from_expr(domain), **kwargs - ) + symbolic_domain = _make_symbolic_domain_tuple(domain) + + transformed_call, _ = infer_expr(stmt.expr, symbolic_domain, **kwargs) return itir.SetAt( expr=transformed_call, From d86551413f0205c5f321d1a33603990f98759bd4 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 22 Aug 2025 14:29:23 +0200 Subject: [PATCH 10/44] Fix some tests --- src/gt4py/next/embedded/operators.py | 4 ++-- src/gt4py/next/ffront/past_passes/type_deduction.py | 3 ++- tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index b309c6601e..6e8d79afb7 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -142,9 +142,9 @@ def _get_vertical_range( domain: common.Domain | tuple[common.Domain, ...], ) -> common.NamedRange | eve.NothingType | tuple[common.NamedRange | eve.NothingType, ...]: if isinstance(domain, tuple): - return tuple(_get_vertical_range(sub) for sub in domain) + return tuple(_get_vertical_range(dom) for dom in domain) else: - vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] + vertical_dim_filtered = [nr for nr in common.domain(domain) if nr.dim.kind == 
common.DimensionKind.VERTICAL] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 2980014e27..90c68f2ff9 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -63,7 +63,8 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: domain_kwarg_ = [domain_kwarg] elif isinstance(domain_kwarg, past.TupleExpr): out_kwarg = new_kwargs["out"] - assert isinstance(out_kwarg, past.TupleExpr) + if not isinstance(out_kwarg, past.TupleExpr): + raise ValueError(f"TupleExpr are only allowed in 'domain', if '{out_kwarg}' is a tuple as well.") assert len(out_kwarg.elts) == len(domain_kwarg.elts) domain_kwarg_ = domain_kwarg.elts for dom in domain_kwarg_: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index ad985e7ee8..4185ef0bd7 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -154,7 +154,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): assert exc_info.match("Invalid call to 'domain_format_1'") assert ( - re.search("Only Dictionaries allowed in 'domain'", exc_info.value.__cause__.args[0]) + re.search("TupleExpr are only allowed in 'domain', if", exc_info.value.__cause__.args[0]) is not None ) From 65a831b9f91dfa6f611599349dc3227a0b19ec22 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 22 Aug 2025 16:01:28 +0200 Subject: [PATCH 11/44] Start working on direct fo calls with multiple output domains --- src/gt4py/next/ffront/decorator.py | 15 ++++++++++-- src/gt4py/next/ffront/past_process_args.py | 25 +++++++++---------- src/gt4py/next/ffront/past_to_itir.py | 28 ++++++++++++++-------- 3 
files changed, 44 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index fe3e2410fc..98390b4c96 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -625,6 +625,16 @@ def program_inner(definition: types.FunctionType) -> Program: OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) +def _slice_outs( + outs: common.Field | tuple[common.Field | tuple, ...], + domains: common.Domain | tuple[common.Domain | tuple, ...], +) -> common.Field | tuple[common.Field | tuple, ...]: + if isinstance(outs, tuple): + if not isinstance(domains, tuple): + domains = tuple([domains] * len(outs)) + return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) + else: + return outs[common.domain(domains)] @dataclasses.dataclass(frozen=True) class FieldOperator(GTCallable, Generic[OperatorNodeT]): @@ -767,8 +777,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - domain = common.domain(kwargs.pop("domain")) - out = utils.tree_map(lambda f: f[domain])(out) + out = _slice_outs(out, kwargs.pop("domain")) + #domain = common.domain(kwargs.pop("domain")) + #out = utils.tree_map(lambda f: f[domain])(out) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index f0360e05ba..69789ba2e9 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -95,19 +95,20 @@ def _process_args( # for `(scalar, (field1, field2))` the two fields need to have the same shape and # dimension if ranges_and_dims: - range_, dims = ranges_and_dims[0] - if not all( - el_range == range_ and el_dims == dims - for (el_range, el_dims) in ranges_and_dims - ): - raise ValueError( - 
"Constituents of composite arguments (e.g. the elements of a" - " tuple) need to have the same shape and dimensions." - ) + # range_, dims = ranges_and_dims[0] # TODO + # if not all( + # el_range == range_ and el_dims == dims + # for (el_range, el_dims) in ranges_and_dims + # ): + # raise ValueError( + # "Constituents of composite arguments (e.g. the elements of a" + # " tuple) need to have the same shape and dimensions." + # ) index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - size_args.extend( - range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty - ) + for range_, dims in ranges_and_dims: + size_args.extend( + range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty + ) return tuple(rewritten_args), tuple(size_args), kwargs diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4bb76986cc..9616fe8fab 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -236,7 +236,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: .to_list() ) if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` - assert all(field_dims == fields_dims[0] for field_dims in fields_dims) + #assert all(field_dims == fields_dims[0] for field_dims in fields_dims) index_type = ts.ScalarType( kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ) @@ -247,6 +247,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: type=ts.TupleType(types=[index_type, index_type]), ) ) + for field_dims in fields_dims: + for dim_idx in range(len(field_dims)): + size_params.append( + itir.Sym( + id=_range_arg_from_field(param.id, dim_idx), + type=ts.TupleType(types=[index_type, index_type]), + ) + ) return size_params @@ -355,15 +363,15 @@ def _construct_itir_domain_arg( 
assert isinstance(out_field.type, ts.TypeSpec) out_field_types = type_info.primitive_constituents(out_field.type).to_list() out_dims = cast(ts.FieldType, out_field_types[0]).dims - if any( - not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims - for out_field_type in out_field_types - ): - raise AssertionError( - f"Expected constituents of '{out_field.id}' argument to be" - " fields defined on the same dimensions. This error should be " - " caught in type deduction already." - ) + # if any( + # not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims + # for out_field_type in out_field_types + # ): # TODO + # raise AssertionError( + # f"Expected constituents of '{out_field.id}' argument to be" + # " fields defined on the same dimensions. This error should be " + # " caught in type deduction already." + # ) domain_args = [] domain_args_kind = [] From 510ea691a5efa41f7b3dae3b0a4100dbe8472943 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 22 Aug 2025 16:08:05 +0200 Subject: [PATCH 12/44] Add tests --- .../test_multiple_output_domains.py | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py new file mode 100644 index 0000000000..0fea268734 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -0,0 +1,129 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + +import gt4py.next as gtx + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import IDim, JDim, cartesian_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import exec_alloc_descriptor + +pytestmark = pytest.mark.uses_cartesian_shift + + +@gtx.field_operator +def testee_orig( + a: gtx.Field[[IDim], gtx.float32], b: gtx.Field[[IDim], gtx.float32] +) -> tuple[gtx.Field[[IDim], gtx.float32], gtx.Field[[IDim], gtx.float32]]: + return b, a + + +@gtx.program +def prog_orig( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[IDim], gtx.float32], + out_a: gtx.Field[[IDim], gtx.float32], + out_b: gtx.Field[[IDim], gtx.float32], +): + testee_orig(a, b, out=(out_b, out_a), domain={IDim: (0, 10)}) + + +def test_program_orig(cartesian_case): + a = cases.allocate(cartesian_case, prog_orig, "a")() + b = cases.allocate(cartesian_case, prog_orig, "b")() + out_a = cases.allocate(cartesian_case, prog_orig, "out_a")() + out_b = cases.allocate(cartesian_case, prog_orig, "out_b")() + + cases.verify( + cartesian_case, + prog_orig, + a, + b, + out_a, + out_b, + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.field_operator +def testee( + a: gtx.Field[[IDim], gtx.float32], b: gtx.Field[[JDim], gtx.float32] +) -> tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[IDim], gtx.float32]]: + return b, a + + +@gtx.program +def prog( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + out_a: gtx.Field[[IDim], gtx.float32], + out_b: gtx.Field[[JDim], gtx.float32], + i_size: gtx.int32, + j_size: gtx.int32, +): + testee(a, b, out=(out_b, out_a), domain=({JDim: (0, j_size)}, {IDim: (0, i_size)})) + + +def test_program(cartesian_case): + a = cases.allocate(cartesian_case, prog, "a")() + b = cases.allocate(cartesian_case, prog, "b")() + out_a = cases.allocate(cartesian_case, prog, "out_a")() + out_b = 
cases.allocate(cartesian_case, prog, "out_b")() + + cases.verify( + cartesian_case, + prog, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +def test_direct_fo_orig(cartesian_case): + a = cases.allocate(cartesian_case, testee_orig, "a")() + b = cases.allocate(cartesian_case, testee_orig, "b")() + out = cases.allocate(cartesian_case, testee_orig, cases.RETURN)() + + cases.verify( + cartesian_case, + testee_orig, + a, + b, + out=out, + ref=(b, a), + domain={IDim: (0, cartesian_case.default_sizes[IDim])} + ) + + +def test_direct_fo(cartesian_case): + a = cases.allocate(cartesian_case, testee, "a")() + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify( + cartesian_case, + testee, + a, + b, + out=out, + ref=(b, a), + domain=( + {JDim: (0, cartesian_case.default_sizes[JDim])}, + {IDim: (0, cartesian_case.default_sizes[IDim])}, + ), + ) + + From b910167f0f7fab97c79df87f1494279b0e269d10 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 22 Aug 2025 17:47:55 +0200 Subject: [PATCH 13/44] Fix embedded domain promotion --- src/gt4py/next/embedded/operators.py | 14 ++----- src/gt4py/next/ffront/decorator.py | 26 ++++++------ src/gt4py/next/ffront/past_process_args.py | 30 +++++++------ src/gt4py/next/ffront/past_to_itir.py | 36 ++++++++-------- .../test_multiple_output_domains.py | 42 +++++++++++-------- 5 files changed, 76 insertions(+), 72 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 6e8d79afb7..85b8cf11bb 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -108,16 +108,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> domain = kwargs.pop("domain", None) - if domain is not None: - if isinstance(out, tuple) and not isinstance(domain, tuple): - 
out_domain = tuple([domain] * len(out)) - else: - out_domain = domain - else: - if isinstance(out, tuple): - out_domain = tuple(_get_out_domain(o) for o in out) - else: - out_domain = _get_out_domain(out) + out_domain = domain if domain is not None else _get_out_domain(out) # TODO? new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) @@ -163,9 +154,12 @@ def impl(target: common.MutableField, source: common.Field, domain: common.Domai assert core_defs.is_scalar_type(source) target[domain] = source + if not isinstance(domain, tuple): # TODO: use a generic condition that also works for nested domains and targets + domain = utils.tree_map(lambda _: domain)(target) impl(target, source, domain) + def _intersect_scan_args( *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> common.Domain: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 98390b4c96..8cde17ffc7 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -625,16 +625,16 @@ def program_inner(definition: types.FunctionType) -> Program: OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) -def _slice_outs( - outs: common.Field | tuple[common.Field | tuple, ...], - domains: common.Domain | tuple[common.Domain | tuple, ...], -) -> common.Field | tuple[common.Field | tuple, ...]: - if isinstance(outs, tuple): - if not isinstance(domains, tuple): - domains = tuple([domains] * len(outs)) - return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) - else: - return outs[common.domain(domains)] +# def _slice_outs( +# outs: common.Field | tuple[common.Field | tuple, ...], +# domains: common.Domain | tuple[common.Domain | tuple, ...], +# ) -> common.Field | tuple[common.Field | tuple, ...]: +# if isinstance(outs, tuple): +# if not isinstance(domains, tuple): +# domains = tuple([domains] * len(outs)) +# return tuple(_slice_outs(out, 
domain) for out, domain in zip(outs, domains, strict=True)) +# else: +# return outs[common.domain(domains)] @dataclasses.dataclass(frozen=True) class FieldOperator(GTCallable, Generic[OperatorNodeT]): @@ -777,9 +777,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - out = _slice_outs(out, kwargs.pop("domain")) - #domain = common.domain(kwargs.pop("domain")) - #out = utils.tree_map(lambda f: f[domain])(out) + # out = _slice_outs(out, kwargs.pop("domain")) + domain = common.domain(kwargs.pop("domain")) + out = utils.tree_map(lambda f: f[domain])(out) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 69789ba2e9..d382346f59 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -95,20 +95,24 @@ def _process_args( # for `(scalar, (field1, field2))` the two fields need to have the same shape and # dimension if ranges_and_dims: - # range_, dims = ranges_and_dims[0] # TODO - # if not all( - # el_range == range_ and el_dims == dims - # for (el_range, el_dims) in ranges_and_dims - # ): - # raise ValueError( - # "Constituents of composite arguments (e.g. the elements of a" - # " tuple) need to have the same shape and dimensions." - # ) - index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - for range_, dims in ranges_and_dims: - size_args.extend( - range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty + range_, dims = ranges_and_dims[0] + if not all( + el_range == range_ and el_dims == dims + for (el_range, el_dims) in ranges_and_dims + ): + raise ValueError( + "Constituents of composite arguments (e.g. the elements of a" + " tuple) need to have the same shape and dimensions." 
) + index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + size_args.extend( + range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) + # type: ignore[arg-type] # shape is always empty + ) + # for range_, dims in ranges_and_dims: + # size_args.extend( + # range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty + # ) return tuple(rewritten_args), tuple(size_args), kwargs diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 9616fe8fab..944a6bdbf9 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -236,7 +236,7 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: .to_list() ) if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` - #assert all(field_dims == fields_dims[0] for field_dims in fields_dims) + assert all(field_dims == fields_dims[0] for field_dims in fields_dims) index_type = ts.ScalarType( kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ) @@ -247,14 +247,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: type=ts.TupleType(types=[index_type, index_type]), ) ) - for field_dims in fields_dims: - for dim_idx in range(len(field_dims)): - size_params.append( - itir.Sym( - id=_range_arg_from_field(param.id, dim_idx), - type=ts.TupleType(types=[index_type, index_type]), - ) - ) + # for field_dims in fields_dims: + # for dim_idx in range(len(field_dims)): + # size_params.append( + # itir.Sym( + # id=_range_arg_from_field(param.id, dim_idx), + # type=ts.TupleType(types=[index_type, index_type]), + # ) + # ) return size_params @@ -363,15 +363,15 @@ def _construct_itir_domain_arg( assert isinstance(out_field.type, ts.TypeSpec) out_field_types = type_info.primitive_constituents(out_field.type).to_list() out_dims = cast(ts.FieldType, out_field_types[0]).dims - # if 
any( - # not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims - # for out_field_type in out_field_types - # ): # TODO - # raise AssertionError( - # f"Expected constituents of '{out_field.id}' argument to be" - # " fields defined on the same dimensions. This error should be " - # " caught in type deduction already." - # ) + if any( + not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims + for out_field_type in out_field_types + ): # TODO + raise AssertionError( + f"Expected constituents of '{out_field.id}' argument to be" + " fields defined on the same dimensions. This error should be " + " caught in type deduction already." + ) domain_args = [] domain_args_kind = [] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 0fea268734..10b6d5ac70 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np import pytest import gt4py.next as gtx @@ -107,23 +106,30 @@ def test_direct_fo_orig(cartesian_case): domain={IDim: (0, cartesian_case.default_sizes[IDim])} ) +# TODO: +# - test without domain +# - test with nested tuples +# - test with different vertical domains KDim and KHalfDim +# - test from https://hackmd.io/m__8sBBATiqFWOPNMEPsfg +# - unstructured test with Local dimensions e.g. 
Vertex, E2V and Edge -def test_direct_fo(cartesian_case): - a = cases.allocate(cartesian_case, testee, "a")() - b = cases.allocate(cartesian_case, testee, "b")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() - - cases.verify( - cartesian_case, - testee, - a, - b, - out=out, - ref=(b, a), - domain=( - {JDim: (0, cartesian_case.default_sizes[JDim])}, - {IDim: (0, cartesian_case.default_sizes[IDim])}, - ), - ) +# +# def test_direct_fo(cartesian_case): +# a = cases.allocate(cartesian_case, testee, "a")() +# b = cases.allocate(cartesian_case, testee, "b")() +# out = cases.allocate(cartesian_case, testee, cases.RETURN)() +# +# cases.verify( +# cartesian_case, +# testee, +# a, +# b, +# out=out, +# ref=(b, a), +# domain=( +# {JDim: (0, cartesian_case.default_sizes[JDim])}, +# {IDim: (0, cartesian_case.default_sizes[IDim])}, +# ), +# ) From 8b02398efc86fbc6e482b8d5d06cbbd986998011 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 29 Aug 2025 11:22:55 +0200 Subject: [PATCH 14/44] Add more tests and extend type deduction --- src/gt4py/next/embedded/operators.py | 9 +- .../next/ffront/past_passes/type_deduction.py | 86 ++--- src/gt4py/next/ffront/past_to_itir.py | 6 +- tests/next_tests/integration_tests/cases.py | 2 + .../ffront_tests/ffront_test_utils.py | 10 +- .../test_multiple_output_domains.py | 297 +++++++++++++++++- 6 files changed, 355 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 85b8cf11bb..d1b308fc7c 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -135,7 +135,9 @@ def _get_vertical_range( if isinstance(domain, tuple): return tuple(_get_vertical_range(dom) for dom in domain) else: - vertical_dim_filtered = [nr for nr in common.domain(domain) if nr.dim.kind == common.DimensionKind.VERTICAL] + vertical_dim_filtered = [ + nr for nr in common.domain(domain) if nr.dim.kind == common.DimensionKind.VERTICAL + ] assert 
len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING @@ -154,12 +156,13 @@ def impl(target: common.MutableField, source: common.Field, domain: common.Domai assert core_defs.is_scalar_type(source) target[domain] = source - if not isinstance(domain, tuple): # TODO: use a generic condition that also works for nested domains and targets + if not isinstance( + domain, tuple + ): # TODO: use a generic condition that also works for nested domains and targets domain = utils.tree_map(lambda _: domain)(target) impl(target, source, domain) - def _intersect_scan_args( *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> common.Domain: diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 90c68f2ff9..7ef69712fe 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -56,37 +56,44 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) - domain_kwarg = new_kwargs["domain"] - if not isinstance(domain_kwarg, (past.Dict, past.TupleExpr)): - raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") - if isinstance(domain_kwarg, past.Dict): - domain_kwarg_ = [domain_kwarg] - elif isinstance(domain_kwarg, past.TupleExpr): - out_kwarg = new_kwargs["out"] - if not isinstance(out_kwarg, past.TupleExpr): - raise ValueError(f"TupleExpr are only allowed in 'domain', if '{out_kwarg}' is a tuple as well.") - assert len(out_kwarg.elts) == len(domain_kwarg.elts) - domain_kwarg_ = domain_kwarg.elts - for dom in domain_kwarg_: - if len(dom.values_) == 0 and len(dom.keys_) == 0: - raise ValueError("Empty domain not allowed.") - - for dim in dom.keys_: - if not isinstance(dim.type, ts.DimensionType): + def check(dom, out): + if 
isinstance(dom, past.Dict): + if len(dom.values_) == 0 and len(dom.keys_) == 0: + raise ValueError("Empty domain not allowed.") + + for dim in dom.keys_: + if not isinstance(dim.type, ts.DimensionType): + raise ValueError( + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." + ) + for domain_values in dom.values_: + if len(domain_values.elts) != 2: + raise ValueError( + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." + ) + if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( + domain_values.elts[1] + ): + raise ValueError( + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." + ) + elif isinstance(dom, past.TupleExpr): + if not isinstance(out, past.TupleExpr) and not isinstance(out.type, ts.TupleType): raise ValueError( - f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." - ) - for domain_values in dom.values_: - if len(domain_values.elts) != 2: - raise ValueError( - f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." - ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): - raise ValueError( - f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." + f"TupleExpr are only allowed in 'domain', if '{out}' is a tuple as well." 
) + if isinstance(out, past.TupleExpr): + out_elts = out.elts + assert len(out_elts) == len(dom.elts) + else: + out_elts = out.type.types + assert len(out_elts) == len(dom.elts) + for d, o in zip(dom.elts, out_elts): + check(d, o) + else: + raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(dom)}'.") + + check(new_kwargs["domain"], new_kwargs["out"]) class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -141,15 +148,22 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) - ttype: ts.TupleType - if any(isinstance(elt, past.Dict) for elt in node.elts): - assert all(isinstance(elt, past.Dict) for elt in node.elts) - # TODO: add check that Dict is DomainLike - ttype = ts.TupleType( - types=[ts.DomainType(dims=[common.Dimension(elt.keys_[0].id)]) for elt in elts] - ) + if any(isinstance(elt, past.Dict) for elt in elts): + assert all(isinstance(elt, (past.Dict, past.TupleExpr)) for elt in elts) + + def infer_type(elt): + if isinstance(elt, past.Dict): + # TODO: add check that Dict is DomainLike + return ts.DomainType(dims=[common.Dimension(elt.keys_[0].id)]) + elif isinstance(elt, past.TupleExpr): + return ts.TupleType(types=[infer_type(elt) for elt in elt.elts]) + else: + raise AssertionError(f"Unexpected element type {type(elt)} inside TupleExpr") + + ttype = ts.TupleType(types=[infer_type(elt) for elt in elts]) else: ttype = ts.TupleType(types=[elt.type for elt in elts]) + return past.TupleExpr(elts=elts, type=ttype, location=node.location) def _deduce_binop_type( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 944a6bdbf9..70d3dc12d5 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -366,7 +366,7 @@ def _construct_itir_domain_arg( if any( not isinstance(out_field_type, 
ts.FieldType) or out_field_type.dims != out_dims for out_field_type in out_field_types - ): # TODO + ): # TODO raise AssertionError( f"Expected constituents of '{out_field.id}' argument to be" " fields defined on the same dimensions. This error should be " @@ -499,11 +499,11 @@ def _visit_stencil_call_out_arg( first_field = first_field.value if isinstance(domain_arg, past.TupleExpr): - domain_args = [ + domain_args = [ # TODO: Test with out as one argument which is a tuple, don't flatten field self._construct_itir_domain_arg(field, domain, None) for field, domain in zip( flattened, _flatten_tuple_expr(domain_arg), strict=True - ) + ) # TODO use field type -> apply_to_primitive_constituents, path -> find relevant Dict expr -> call _construct_itir_domain_arg, test with wrong structure as well ] domain_expr = im.make_tuple(*domain_args) return self._construct_itir_out_arg(out_arg), domain_expr diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 967cf0ab11..19e8e3ea6d 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -50,6 +50,7 @@ JDim, Joff, KDim, + KHalfDim, Koff, V2EDim, Vertex, @@ -701,6 +702,7 @@ def from_cartesian_grid_descriptor( IDim: grid_descriptor.sizes[0], JDim: grid_descriptor.sizes[1], KDim: grid_descriptor.sizes[2], + KHalfDim: grid_descriptor.sizes[3], }, grid_type=common.GridType.CARTESIAN, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index b2cb8b0a2c..cd6aaf5de3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -135,6 +135,7 @@ def debug_itir(tree): IDim = gtx.Dimension("IDim") JDim = gtx.Dimension("JDim") KDim = gtx.Dimension("KDim", 
kind=gtx.DimensionKind.VERTICAL) +KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) Joff = gtx.FieldOffset("Joff", source=JDim, target=(JDim,)) Koff = gtx.FieldOffset("Koff", source=KDim, target=(KDim,)) @@ -170,15 +171,18 @@ def offset_provider(self) -> common.OffsetProvider: ... def offset_provider_type(self) -> common.OffsetProviderType: ... -def simple_cartesian_grid(sizes: int | tuple[int, int, int] = 10) -> CartesianGridDescriptor: +def simple_cartesian_grid( + sizes: int | tuple[int, int, int, int] = (5, 7, 9, 11), +) -> CartesianGridDescriptor: if isinstance(sizes, int): - sizes = (sizes,) * 3 - assert len(sizes) == 3, "sizes must be a tuple of three integers" + sizes = (sizes,) * 4 + assert len(sizes) == 4, "sizes must be a tuple of four integers" offset_provider = { "Ioff": IDim, "Joff": JDim, "Koff": KDim, + "KHalfoff": KHalfDim, } return types.SimpleNamespace( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 10b6d5ac70..b4a7df9127 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -11,9 +11,27 @@ import gt4py.next as gtx from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import IDim, JDim, cartesian_case -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import exec_alloc_descriptor +from next_tests.integration_tests.cases import ( + IDim, + JDim, + KDim, + C2E, + E2V, + V2E, + Edge, + Cell, + Vertex, + cartesian_case, + Case, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) +from 
gt4py.next import common + +KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) pytestmark = pytest.mark.uses_cartesian_shift @@ -30,8 +48,9 @@ def prog_orig( b: gtx.Field[[IDim], gtx.float32], out_a: gtx.Field[[IDim], gtx.float32], out_b: gtx.Field[[IDim], gtx.float32], + i_size: gtx.int32, ): - testee_orig(a, b, out=(out_b, out_a), domain={IDim: (0, 10)}) + testee_orig(a, b, out=(out_b, out_a), domain={IDim: (0, i_size)}) def test_program_orig(cartesian_case): @@ -47,6 +66,35 @@ def test_program_orig(cartesian_case): b, out_a, out_b, + cartesian_case.default_sizes[IDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.program +def prog_no_domain( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[IDim], gtx.float32], + out_a: gtx.Field[[IDim], gtx.float32], + out_b: gtx.Field[[IDim], gtx.float32], +): + testee_orig(a, b, out=(out_b, out_a)) + + +def test_program_no_domain(cartesian_case): + a = cases.allocate(cartesian_case, prog_no_domain, "a")() + b = cases.allocate(cartesian_case, prog_no_domain, "b")() + out_a = cases.allocate(cartesian_case, prog_no_domain, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_domain, "out_b")() + + cases.verify( + cartesian_case, + prog_no_domain, + a, + b, + out_a, + out_b, inout=(out_b, out_a), ref=(b, a), ) @@ -91,6 +139,238 @@ def test_program(cartesian_case): ) +@gtx.program +def prog_out_as_tuple( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + out: tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[IDim], gtx.float32]], + i_size: gtx.int32, + j_size: gtx.int32, +): + testee(a, b, out=out, domain=({JDim: (0, j_size)}, {IDim: (0, i_size)})) + + +def test_program_out_as_tuple( + cartesian_case, +): # TODO: this fails for most backends, merge PR #1893 first + a = cases.allocate(cartesian_case, prog_out_as_tuple, "a")() + b = cases.allocate(cartesian_case, prog_out_as_tuple, "b")() + out = cases.allocate(cartesian_case, prog_out_as_tuple, "out")() + + cases.verify( + 
cartesian_case, + prog_out_as_tuple, + a, + b, + out, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + inout=(out), + ref=(b, a), + ) + + +@gtx.field_operator +def testee_nested_tuples( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + c: gtx.Field[[KDim], gtx.float32], +) -> tuple[ + tuple[gtx.Field[[IDim], gtx.float32], gtx.Field[[IDim], gtx.float32]], + gtx.Field[[IDim], gtx.float32], +]: + return (a, a), a + + +@gtx.program +def prog_nested_tuples( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + c: gtx.Field[[KDim], gtx.float32], + out_a: gtx.Field[[IDim], gtx.float32], + out_b: gtx.Field[[JDim], gtx.float32], + out_c: gtx.Field[[KDim], gtx.float32], + i_size: gtx.int32, + j_size: gtx.int32, + k_size: gtx.int32, +): + testee_nested_tuples( + a, + b, + c, + out=((out_a, out_a), out_a), + domain=(({IDim: (0, i_size)}, {IDim: (0, i_size)}), {IDim: (0, i_size)}), + ) # TODO: use JDim, KDim + + +def test_program_nested_tuples( + cartesian_case, +): # TODO: this fails for most backends, merge PR #1893 first + a = cases.allocate(cartesian_case, prog_nested_tuples, "a")() + b = cases.allocate(cartesian_case, prog_nested_tuples, "b")() + c = cases.allocate(cartesian_case, prog_nested_tuples, "c")() + out_a = cases.allocate(cartesian_case, prog_nested_tuples, "out_a")() + out_b = cases.allocate(cartesian_case, prog_nested_tuples, "out_b")() + out_c = cases.allocate(cartesian_case, prog_nested_tuples, "out_c")() + + cases.verify( + cartesian_case, + prog_nested_tuples, + a, + b, + c, + out_a, + out_b, + out_c, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + cartesian_case.default_sizes[KDim], + inout=((out_a, out_a), out_a), + ref=((a, a), a), + ) + + +@gtx.field_operator +def testee_two_vertical_dims( + a: gtx.Field[[KDim], gtx.float32], b: gtx.Field[[KHalfDim], gtx.float32] +) -> tuple[gtx.Field[[KHalfDim], gtx.float32], gtx.Field[[KDim], gtx.float32]]: + 
return b, a + + +@gtx.program +def prog_two_vertical_dims( + a: gtx.Field[[KDim], gtx.float32], + b: gtx.Field[[KHalfDim], gtx.float32], + out_a: gtx.Field[[KDim], gtx.float32], + out_b: gtx.Field[[KHalfDim], gtx.float32], + k_size: gtx.int32, + k_half_size: gtx.int32, +): + testee_two_vertical_dims( + a, b, out=(out_b, out_a), domain=({KHalfDim: (0, k_half_size)}, {KDim: (0, k_size)}) + ) + + +def test_program_two_vertical_dims(cartesian_case): + a = cases.allocate(cartesian_case, prog_two_vertical_dims, "a")() + b = cases.allocate(cartesian_case, prog_two_vertical_dims, "b")() + out_a = cases.allocate(cartesian_case, prog_two_vertical_dims, "out_a")() + out_b = cases.allocate(cartesian_case, prog_two_vertical_dims, "out_b")() + + cases.verify( + cartesian_case, + prog_two_vertical_dims, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[KDim], + cartesian_case.default_sizes[KHalfDim], + inout=(out_b, out_a), + ref=(b, a), + ) + + +@gtx.field_operator +def testee_shift_e2c(a: cases.EField) -> tuple[cases.CField, cases.EField]: + return a(C2E[1]), a + + +@gtx.program +def prog_unstructured( + a: cases.EField, + out_a: cases.EField, + out_a_shifted: cases.CField, + c_size: gtx.int32, + e_size: gtx.int32, +): + testee_shift_e2c( + a, out=(out_a_shifted, out_a), domain=({Cell: (0, c_size)}, {Edge: (0, e_size)}) + ) + + +def test_program_unstructured( + exec_alloc_descriptor, mesh_descriptor +): # TODO: this fails for definitions_numpy, please see test_temporaries_with_sizes.py + unstructured_case = Case( + exec_alloc_descriptor, + offset_provider=mesh_descriptor.offset_provider, + default_sizes={ + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, + ) + a = cases.allocate(unstructured_case, prog_unstructured, "a")() + out_a = cases.allocate(unstructured_case, prog_unstructured, "out_a")() + out_a_shifted = cases.allocate(unstructured_case, 
prog_unstructured, "out_a_shifted")() + + cases.verify( + unstructured_case, + prog_unstructured, + a, + out_a, + out_a_shifted, + unstructured_case.default_sizes[Cell], + unstructured_case.default_sizes[Edge], + inout=(out_a_shifted, out_a), + ref=((a.ndarray)[mesh_descriptor.offset_provider["C2E"].asnumpy()[:, 1]], a), + ) + + +@gtx.field_operator +def testee_temporary(a: cases.VField): + edge = a(E2V[1]) + cell = edge(C2E[1]) + return edge, cell + + +@gtx.program +def prog_temporary( + a: cases.VField, + out_edge: cases.EField, + out_cell: cases.CField, + c_size: gtx.int32, + e_size: gtx.int32, +): + testee_temporary( + a, out=(out_edge, out_cell), domain=({Edge: (0, e_size)}, {Cell: (0, c_size)}) + ) # TODO: specify other domain sizes? + + +def test_program_temporary( + exec_alloc_descriptor, mesh_descriptor +): # TODO: this fails for definitions_numpy, please see test_temporaries_with_sizes.py + unstructured_case = Case( + exec_alloc_descriptor, + offset_provider=mesh_descriptor.offset_provider, + default_sizes={ + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + Vertex: mesh_descriptor.num_vertices, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, + ) + a = cases.allocate(unstructured_case, prog_temporary, "a")() + out_edge = cases.allocate(unstructured_case, prog_temporary, "out_edge")() + out_cell = cases.allocate(unstructured_case, prog_temporary, "out_cell")() + + e2v = (a.ndarray)[mesh_descriptor.offset_provider["E2V"].asnumpy()[:, 1]] + cases.verify( + unstructured_case, + prog_temporary, + a, + out_edge, + out_cell, + unstructured_case.default_sizes[Cell], + unstructured_case.default_sizes[Edge], + inout=(out_edge, out_cell), + ref=(e2v, e2v[mesh_descriptor.offset_provider["C2E"].asnumpy()[:, 1]]), + ) + + def test_direct_fo_orig(cartesian_case): a = cases.allocate(cartesian_case, testee_orig, "a")() b = cases.allocate(cartesian_case, testee_orig, "b")() @@ -103,15 +383,14 @@ def 
test_direct_fo_orig(cartesian_case): b, out=out, ref=(b, a), - domain={IDim: (0, cartesian_case.default_sizes[IDim])} + domain={IDim: (0, cartesian_case.default_sizes[IDim])}, ) + # TODO: -# - test without domain -# - test with nested tuples -# - test with different vertical domains KDim and KHalfDim +# - test with double nested tuples # - test from https://hackmd.io/m__8sBBATiqFWOPNMEPsfg -# - unstructured test with Local dimensions e.g. Vertex, E2V and Edge +# - vertical staggering with dependency # # def test_direct_fo(cartesian_case): @@ -131,5 +410,3 @@ def test_direct_fo_orig(cartesian_case): # {IDim: (0, cartesian_case.default_sizes[IDim])}, # ), # ) - - From 17bf43241551a8c90e89c939b84be945600b3e90 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 1 Sep 2025 17:59:24 +0200 Subject: [PATCH 15/44] Extend for nested tuples --- .../next/ffront/past_passes/type_deduction.py | 29 +++--- src/gt4py/next/ffront/past_to_itir.py | 21 +++-- .../test_multiple_output_domains.py | 88 ++++++++++++++++--- 3 files changed, 108 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 7ef69712fe..3f771f281a 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -56,8 +56,14 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) - def check(dom, out): + def check( + dom: past.Dict | past.TupleExpr, out: past.TupleExpr | past.Name, level: int = 0 + ) -> None: if isinstance(dom, past.Dict): + # Only reject tuple outputs if nested (level > 0) + if level > 0 and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): + raise ValueError("Domain dict cannot map to tuple outputs.") + if len(dom.values_) == 0 and len(dom.keys_) == 0: raise ValueError("Empty domain not allowed.") @@ -78,20 +84,21 @@ def 
check(dom, out): f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." ) elif isinstance(dom, past.TupleExpr): - if not isinstance(out, past.TupleExpr) and not isinstance(out.type, ts.TupleType): - raise ValueError( - f"TupleExpr are only allowed in 'domain', if '{out}' is a tuple as well." - ) if isinstance(out, past.TupleExpr): out_elts = out.elts - assert len(out_elts) == len(dom.elts) - else: + elif isinstance(out.type, ts.TupleType): out_elts = out.type.types - assert len(out_elts) == len(dom.elts) + else: + raise ValueError(f"Tuple domain requires tuple output, got {type(out)}.") + + if len(dom.elts) != len(out_elts): + raise ValueError("Mismatched tuple lengths between domain and output.") + for d, o in zip(dom.elts, out_elts): - check(d, o) + check(d, o, level=level + 1) + else: - raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(dom)}'.") + raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.") check(new_kwargs["domain"], new_kwargs["out"]) @@ -151,7 +158,7 @@ def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr if any(isinstance(elt, past.Dict) for elt in elts): assert all(isinstance(elt, (past.Dict, past.TupleExpr)) for elt in elts) - def infer_type(elt): + def infer_type(elt: past.Dict | past.TupleExpr) -> ts.DomainType | ts.TupleType: if isinstance(elt, past.Dict): # TODO: add check that Dict is DomainLike return ts.DomainType(dims=[common.Dimension(elt.keys_[0].id)]) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index bb64dfcd92..24a4be423d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -475,13 +475,20 @@ def _visit_stencil_call_out_arg( first_field = first_field.value if isinstance(domain_arg, past.TupleExpr): - domain_args = [ # TODO: Test with out as one argument which is a tuple, don't flatten field - 
self._construct_itir_domain_arg(field, domain, None) - for field, domain in zip( - flattened, _flatten_tuple_expr(domain_arg), strict=True - ) # TODO use field type -> apply_to_primitive_constituents, path -> find relevant Dict expr -> call _construct_itir_domain_arg, test with wrong structure as well - ] - domain_expr = im.make_tuple(*domain_args) + # TODO: Test with out as one argument which is a tuple, don't flatten field + + domain_expr = type_info.apply_to_primitive_constituents( + lambda field_type, path: self._construct_itir_domain_arg( + functools.reduce(lambda e, i: e.elts[i], path, out_arg), + functools.reduce(lambda e, i: e.elts[i], path, domain_arg) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg, + None, + ), + out_arg.type, + with_path_arg=True, + tuple_constructor=im.make_tuple, + ) return self._construct_itir_out_arg(out_arg), domain_expr else: return ( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index b4a7df9127..bb047160db 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -9,7 +9,6 @@ import pytest import gt4py.next as gtx - from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( IDim, @@ -176,10 +175,10 @@ def testee_nested_tuples( b: gtx.Field[[JDim], gtx.float32], c: gtx.Field[[KDim], gtx.float32], ) -> tuple[ - tuple[gtx.Field[[IDim], gtx.float32], gtx.Field[[IDim], gtx.float32]], - gtx.Field[[IDim], gtx.float32], + tuple[gtx.Field[[IDim], gtx.float32], gtx.Field[[JDim], gtx.float32]], + gtx.Field[[KDim], gtx.float32], ]: - return (a, a), a + return (a, b), c @gtx.program @@ -198,14 +197,14 @@ def prog_nested_tuples( a, b, c, - out=((out_a, out_a), 
out_a), - domain=(({IDim: (0, i_size)}, {IDim: (0, i_size)}), {IDim: (0, i_size)}), - ) # TODO: use JDim, KDim + out=((out_a, out_b), out_c), + domain=(({IDim: (0, i_size)}, {JDim: (0, j_size)}), {KDim: (0, k_size)}), + ) def test_program_nested_tuples( cartesian_case, -): # TODO: this fails for most backends, merge PR #1893 first +): a = cases.allocate(cartesian_case, prog_nested_tuples, "a")() b = cases.allocate(cartesian_case, prog_nested_tuples, "b")() c = cases.allocate(cartesian_case, prog_nested_tuples, "c")() @@ -225,8 +224,74 @@ def test_program_nested_tuples( cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[JDim], cartesian_case.default_sizes[KDim], - inout=((out_a, out_a), out_a), - ref=((a, a), a), + inout=((out_a, out_b), out_c), + ref=((a, b), c), + ) + + +@gtx.field_operator +def testee_double_nested_tuples( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + c: gtx.Field[[KDim], gtx.float32], +) -> tuple[ + tuple[ + gtx.Field[[IDim], gtx.float32], + tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[KDim], gtx.float32]], + ], + gtx.Field[[KDim], gtx.float32], +]: + return (a, (b, c)), c + + +@gtx.program +def prog_double_nested_tuples( + a: gtx.Field[[IDim], gtx.float32], + b: gtx.Field[[JDim], gtx.float32], + c: gtx.Field[[KDim], gtx.float32], + out_a: gtx.Field[[IDim], gtx.float32], + out_b: gtx.Field[[JDim], gtx.float32], + out_c: gtx.Field[[KDim], gtx.float32], + i_size: gtx.int32, + j_size: gtx.int32, + k_size: gtx.int32, +): + testee_double_nested_tuples( + a, + b, + c, + out=((out_a, (out_b, out_c)), out_c), + domain=( + ({IDim: (0, i_size)}, ({JDim: (0, j_size)}, {KDim: (0, k_size)})), + {KDim: (0, k_size)}, + ), + ) + + +def test_program_double_nested_tuples( + cartesian_case, +): + a = cases.allocate(cartesian_case, prog_double_nested_tuples, "a")() + b = cases.allocate(cartesian_case, prog_double_nested_tuples, "b")() + c = cases.allocate(cartesian_case, prog_double_nested_tuples, "c")() + out_a = 
cases.allocate(cartesian_case, prog_double_nested_tuples, "out_a")() + out_b = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_b")() + out_c = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c")() + + cases.verify( + cartesian_case, + prog_double_nested_tuples, + a, + b, + c, + out_a, + out_b, + out_c, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + cartesian_case.default_sizes[KDim], + inout=((out_a, (out_b, out_c)), out_c), + ref=((a, (b, c)), c), ) @@ -388,9 +453,8 @@ def test_direct_fo_orig(cartesian_case): # TODO: -# - test with double nested tuples -# - test from https://hackmd.io/m__8sBBATiqFWOPNMEPsfg # - vertical staggering with dependency +# - cleanup and refactor tests # # def test_direct_fo(cartesian_case): From 87fa841eb421d51c69d7542495682b40f2a5b1dd Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 1 Sep 2025 20:55:03 +0200 Subject: [PATCH 16/44] Cleanup tests --- .../test_multiple_output_domains.py | 114 +++++++++--------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index bb047160db..3fda90f7d1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -18,10 +18,16 @@ E2V, V2E, Edge, + EField, + CField, + VField, Cell, Vertex, cartesian_case, Case, + IField, + JField, + KField, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -35,18 +41,16 @@ @gtx.field_operator -def testee_orig( - a: gtx.Field[[IDim], gtx.float32], b: gtx.Field[[IDim], gtx.float32] -) -> tuple[gtx.Field[[IDim], gtx.float32], gtx.Field[[IDim], gtx.float32]]: +def 
testee_orig(a: IField, b: IField) -> tuple[IField, IField]: return b, a @gtx.program def prog_orig( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[IDim], gtx.float32], - out_a: gtx.Field[[IDim], gtx.float32], - out_b: gtx.Field[[IDim], gtx.float32], + a: IField, + b: IField, + out_a: IField, + out_b: IField, i_size: gtx.int32, ): testee_orig(a, b, out=(out_b, out_a), domain={IDim: (0, i_size)}) @@ -73,10 +77,10 @@ def test_program_orig(cartesian_case): @gtx.program def prog_no_domain( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[IDim], gtx.float32], - out_a: gtx.Field[[IDim], gtx.float32], - out_b: gtx.Field[[IDim], gtx.float32], + a: IField, + b: IField, + out_a: IField, + out_b: IField, ): testee_orig(a, b, out=(out_b, out_a)) @@ -100,18 +104,16 @@ def test_program_no_domain(cartesian_case): @gtx.field_operator -def testee( - a: gtx.Field[[IDim], gtx.float32], b: gtx.Field[[JDim], gtx.float32] -) -> tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[IDim], gtx.float32]]: +def testee(a: IField, b: JField) -> tuple[JField, IField]: return b, a @gtx.program def prog( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - out_a: gtx.Field[[IDim], gtx.float32], - out_b: gtx.Field[[JDim], gtx.float32], + a: IField, + b: JField, + out_a: IField, + out_b: JField, i_size: gtx.int32, j_size: gtx.int32, ): @@ -140,9 +142,9 @@ def test_program(cartesian_case): @gtx.program def prog_out_as_tuple( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - out: tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[IDim], gtx.float32]], + a: IField, + b: JField, + out: tuple[JField, IField], i_size: gtx.int32, j_size: gtx.int32, ): @@ -171,24 +173,24 @@ def test_program_out_as_tuple( @gtx.field_operator def testee_nested_tuples( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - c: gtx.Field[[KDim], gtx.float32], + a: IField, + b: JField, + c: KField, ) -> tuple[ - tuple[gtx.Field[[IDim], gtx.float32], 
gtx.Field[[JDim], gtx.float32]], - gtx.Field[[KDim], gtx.float32], + tuple[IField, JField], + KField, ]: return (a, b), c @gtx.program def prog_nested_tuples( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - c: gtx.Field[[KDim], gtx.float32], - out_a: gtx.Field[[IDim], gtx.float32], - out_b: gtx.Field[[JDim], gtx.float32], - out_c: gtx.Field[[KDim], gtx.float32], + a: IField, + b: JField, + c: KField, + out_a: IField, + out_b: JField, + out_c: KField, i_size: gtx.int32, j_size: gtx.int32, k_size: gtx.int32, @@ -231,27 +233,27 @@ def test_program_nested_tuples( @gtx.field_operator def testee_double_nested_tuples( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - c: gtx.Field[[KDim], gtx.float32], + a: IField, + b: JField, + c: KField, ) -> tuple[ tuple[ - gtx.Field[[IDim], gtx.float32], - tuple[gtx.Field[[JDim], gtx.float32], gtx.Field[[KDim], gtx.float32]], + IField, + tuple[JField, KField], ], - gtx.Field[[KDim], gtx.float32], + KField, ]: return (a, (b, c)), c @gtx.program def prog_double_nested_tuples( - a: gtx.Field[[IDim], gtx.float32], - b: gtx.Field[[JDim], gtx.float32], - c: gtx.Field[[KDim], gtx.float32], - out_a: gtx.Field[[IDim], gtx.float32], - out_b: gtx.Field[[JDim], gtx.float32], - out_c: gtx.Field[[KDim], gtx.float32], + a: IField, + b: JField, + c: KField, + out_a: IField, + out_b: JField, + out_c: KField, i_size: gtx.int32, j_size: gtx.int32, k_size: gtx.int32, @@ -297,16 +299,16 @@ def test_program_double_nested_tuples( @gtx.field_operator def testee_two_vertical_dims( - a: gtx.Field[[KDim], gtx.float32], b: gtx.Field[[KHalfDim], gtx.float32] -) -> tuple[gtx.Field[[KHalfDim], gtx.float32], gtx.Field[[KDim], gtx.float32]]: + a: KField, b: gtx.Field[[KHalfDim], gtx.float32] +) -> tuple[gtx.Field[[KHalfDim], gtx.float32], KField]: return b, a @gtx.program def prog_two_vertical_dims( - a: gtx.Field[[KDim], gtx.float32], + a: KField, b: gtx.Field[[KHalfDim], gtx.float32], - out_a: gtx.Field[[KDim], 
gtx.float32], + out_a: KField, out_b: gtx.Field[[KHalfDim], gtx.float32], k_size: gtx.int32, k_half_size: gtx.int32, @@ -337,15 +339,15 @@ def test_program_two_vertical_dims(cartesian_case): @gtx.field_operator -def testee_shift_e2c(a: cases.EField) -> tuple[cases.CField, cases.EField]: +def testee_shift_e2c(a: EField) -> tuple[CField, EField]: return a(C2E[1]), a @gtx.program def prog_unstructured( - a: cases.EField, - out_a: cases.EField, - out_a_shifted: cases.CField, + a: EField, + out_a: EField, + out_a_shifted: CField, c_size: gtx.int32, e_size: gtx.int32, ): @@ -385,7 +387,7 @@ def test_program_unstructured( @gtx.field_operator -def testee_temporary(a: cases.VField): +def testee_temporary(a: VField): edge = a(E2V[1]) cell = edge(C2E[1]) return edge, cell @@ -393,9 +395,9 @@ def testee_temporary(a: cases.VField): @gtx.program def prog_temporary( - a: cases.VField, - out_edge: cases.EField, - out_cell: cases.CField, + a: VField, + out_edge: EField, + out_cell: CField, c_size: gtx.int32, e_size: gtx.int32, ): From 54dde3077ca9f4d0dec4639bb535c60252421f3f Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 2 Sep 2025 16:41:53 +0200 Subject: [PATCH 17/44] Extend to also work for out arg that is a tuple --- src/gt4py/next/ffront/past_to_itir.py | 32 ++++-- .../next/iterator/transforms/pass_manager.py | 4 +- .../test_multiple_output_domains.py | 102 ++++++++++++------ 3 files changed, 95 insertions(+), 43 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 24a4be423d..4afb8fd8d4 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -327,14 +327,17 @@ def _construct_itir_domain_arg( slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: assert isinstance(out_field.type, ts.TypeSpec) - out_field_types = type_info.primitive_constituents(out_field.type).to_list() + out_field_types = type_info.primitive_constituents( + out_field.type if hasattr(out_field, 
"type") else out_field + ).to_list() + out_dims = cast(ts.FieldType, out_field_types[0]).dims if any( not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims for out_field_type in out_field_types - ): # TODO + ): raise AssertionError( - f"Expected constituents of '{out_field.id}' argument to be" + f"Expected constituents of '{getattr(out_field, 'id', out_field)}' argument to be" " fields defined on the same dimensions. This error should be " " caught in type deduction already." ) @@ -348,7 +351,6 @@ def _construct_itir_domain_arg( ) domain_args = [] - domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension dim_range = im.call("get_domain_range")( @@ -383,7 +385,6 @@ def _construct_itir_domain_arg( args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], ) ) - domain_args_kind.append(dim.kind) if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" @@ -448,16 +449,17 @@ def _visit_stencil_call_out_arg( out_field_name, domain_arg, self._compute_field_slice(out_arg) ), ) - elif isinstance(out_arg, past.Name): + elif isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.FieldType): return ( self._construct_itir_out_arg(out_arg), self._construct_itir_domain_arg(out_arg, domain_arg), ) - elif isinstance(out_arg, past.TupleExpr): + elif isinstance(out_arg, past.TupleExpr) or ( + isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) + ): flattened = _flatten_tuple_expr(out_arg) first_field = flattened[0] - field_slice = None if isinstance(first_field, past.Subscript): raise AssertionError # TODO support slicing of multiple fields with different domain @@ -475,11 +477,19 @@ def _visit_stencil_call_out_arg( first_field = first_field.value if isinstance(domain_arg, past.TupleExpr): - # TODO: Test with out as one argument which is a tuple, don't flatten field - domain_expr = type_info.apply_to_primitive_constituents( lambda 
field_type, path: self._construct_itir_domain_arg( - functools.reduce(lambda e, i: e.elts[i], path, out_arg), + functools.reduce( + lambda e, i: ( + e.elts[i] + if isinstance(e, past.TupleExpr) + else past.Name(type=e.type.types[i], id=e.id, location=e.location) + if isinstance(e, past.Name) and isinstance(e.type, ts.TupleType) + else e + ), + path, + out_arg, + ), functools.reduce(lambda e, i: e.elts[i], path, domain_arg) if isinstance(domain_arg, past.TupleExpr) else domain_arg, diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e8ecdedc8e..d5fa84d4c0 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -143,10 +143,10 @@ def apply_common_transforms( # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. - if unconditionally_collapse_tuples: + if unconditionally_collapse_tuples: # TODO(sf-n): delete this? 
ir = CollapseTuple.apply( ir, - ignore_tuple_size=True, + ignore_tuple_size=False, uids=collapse_tuple_uids, enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, offset_provider_type=offset_provider_type, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 3fda90f7d1..25c26af333 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -24,6 +24,7 @@ Cell, Vertex, cartesian_case, + unstructured_case, Case, IField, JField, @@ -40,6 +41,41 @@ pytestmark = pytest.mark.uses_cartesian_shift +@gtx.field_operator +def testee_no_tuple(a: IField, b: IField) -> IField: + return a + + +@gtx.program +def prog_no_tuple( + a: IField, + b: IField, + out_a: IField, + out_b: IField, + i_size: gtx.int32, +): + testee_no_tuple(a, b, out=out_a, domain={IDim: (0, i_size)}) + + +def test_program_no_tuple(cartesian_case): + a = cases.allocate(cartesian_case, prog_no_tuple, "a")() + b = cases.allocate(cartesian_case, prog_no_tuple, "b")() + out_a = cases.allocate(cartesian_case, prog_no_tuple, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_tuple, "out_b")() + + cases.verify( + cartesian_case, + prog_no_tuple, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[IDim], + inout=out_a, + ref=a, + ) + + @gtx.field_operator def testee_orig(a: IField, b: IField) -> tuple[IField, IField]: return b, a @@ -108,6 +144,36 @@ def testee(a: IField, b: JField) -> tuple[JField, IField]: return b, a +@gtx.program +def prog_no_domain_differnet_fields( + a: IField, + b: JField, + out_a: IField, + out_b: JField, +): + testee(a, b, out=(out_b, out_a)) + + +def test_program_no_domain_different_fields( + cartesian_case, +): # TODO: this 
still fails for some backends + a = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "a")() + b = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "b")() + out_a = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "out_b")() + + cases.verify( + cartesian_case, + prog_no_domain_differnet_fields, + a, + b, + out_a, + out_b, + inout=(out_b, out_a), + ref=(b, a), + ) + + @gtx.program def prog( a: IField, @@ -356,19 +422,7 @@ def prog_unstructured( ) -def test_program_unstructured( - exec_alloc_descriptor, mesh_descriptor -): # TODO: this fails for definitions_numpy, please see test_temporaries_with_sizes.py - unstructured_case = Case( - exec_alloc_descriptor, - offset_provider=mesh_descriptor.offset_provider, - default_sizes={ - Edge: mesh_descriptor.num_edges, - Cell: mesh_descriptor.num_cells, - }, - grid_type=common.GridType.UNSTRUCTURED, - allocator=exec_alloc_descriptor.allocator, - ) +def test_program_unstructured(unstructured_case): a = cases.allocate(unstructured_case, prog_unstructured, "a")() out_a = cases.allocate(unstructured_case, prog_unstructured, "out_a")() out_a_shifted = cases.allocate(unstructured_case, prog_unstructured, "out_a_shifted")() @@ -382,7 +436,7 @@ def test_program_unstructured( unstructured_case.default_sizes[Cell], unstructured_case.default_sizes[Edge], inout=(out_a_shifted, out_a), - ref=((a.ndarray)[mesh_descriptor.offset_provider["C2E"].asnumpy()[:, 1]], a), + ref=((a.ndarray)[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]], a), ) @@ -406,25 +460,12 @@ def prog_temporary( ) # TODO: specify other domain sizes? 
-def test_program_temporary( - exec_alloc_descriptor, mesh_descriptor -): # TODO: this fails for definitions_numpy, please see test_temporaries_with_sizes.py - unstructured_case = Case( - exec_alloc_descriptor, - offset_provider=mesh_descriptor.offset_provider, - default_sizes={ - Edge: mesh_descriptor.num_edges, - Cell: mesh_descriptor.num_cells, - Vertex: mesh_descriptor.num_vertices, - }, - grid_type=common.GridType.UNSTRUCTURED, - allocator=exec_alloc_descriptor.allocator, - ) +def test_program_temporary(unstructured_case): a = cases.allocate(unstructured_case, prog_temporary, "a")() out_edge = cases.allocate(unstructured_case, prog_temporary, "out_edge")() out_cell = cases.allocate(unstructured_case, prog_temporary, "out_cell")() - e2v = (a.ndarray)[mesh_descriptor.offset_provider["E2V"].asnumpy()[:, 1]] + e2v = (a.ndarray)[unstructured_case.offset_provider["E2V"].asnumpy()[:, 1]] cases.verify( unstructured_case, prog_temporary, @@ -434,7 +475,7 @@ def test_program_temporary( unstructured_case.default_sizes[Cell], unstructured_case.default_sizes[Edge], inout=(out_edge, out_cell), - ref=(e2v, e2v[mesh_descriptor.offset_provider["C2E"].asnumpy()[:, 1]]), + ref=(e2v, e2v[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]]), ) @@ -456,6 +497,7 @@ def test_direct_fo_orig(cartesian_case): # TODO: # - vertical staggering with dependency +# - test with different sizes (explicit domain and no domain) (use extend) # - cleanup and refactor tests # From c2805143661d802a47fa9b2ad0dab23b5cb15c3b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 5 Sep 2025 17:54:23 +0200 Subject: [PATCH 18/44] Add tests with restricted domain and extend to construct domain tuple if output is a tuple --- src/gt4py/next/ffront/past_to_itir.py | 46 ++++---- .../test_multiple_output_domains.py | 111 ++++++++++++++++-- 2 files changed, 122 insertions(+), 35 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 
3ca3bb58d1..a790466904 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -474,35 +474,31 @@ def _visit_stencil_call_out_arg( field_slice = self._compute_field_slice(first_field) first_field = first_field.value - if isinstance(domain_arg, past.TupleExpr): - domain_expr = type_info.apply_to_primitive_constituents( - lambda field_type, path: self._construct_itir_domain_arg( - functools.reduce( - lambda e, i: ( - e.elts[i] - if isinstance(e, past.TupleExpr) - else past.Name(type=e.type.types[i], id=e.id, location=e.location) - if isinstance(e, past.Name) and isinstance(e.type, ts.TupleType) - else e - ), - path, - out_arg, + domain_expr = type_info.apply_to_primitive_constituents( + lambda field_type, path: self._construct_itir_domain_arg( + functools.reduce( + lambda e, i: ( + e.elts[i] + if isinstance(e, past.TupleExpr) + else past.Name(type=e.type.types[i], id=e.id, location=e.location) + if isinstance(e, past.Name) and isinstance(e.type, ts.TupleType) + else e ), + path, + out_arg, + ), + ( functools.reduce(lambda e, i: e.elts[i], path, domain_arg) if isinstance(domain_arg, past.TupleExpr) - else domain_arg, - None, + else domain_arg ), - out_arg.type, - with_path_arg=True, - tuple_constructor=im.make_tuple, - ) - return self._construct_itir_out_arg(out_arg), domain_expr - else: - return ( - self._construct_itir_out_arg(out_arg), - self._construct_itir_domain_arg(first_field, domain_arg, field_slice), - ) + None, # TODO: support slicing + ), + out_arg.type, + with_path_arg=True, + tuple_constructor=im.make_tuple, + ) + return self._construct_itir_out_arg(out_arg), domain_expr else: raise AssertionError( "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." 
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 25c26af333..138d520b6b 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy import pytest import gt4py.next as gtx @@ -156,7 +157,7 @@ def prog_no_domain_differnet_fields( def test_program_no_domain_different_fields( cartesian_case, -): # TODO: this still fails for some backends +): a = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "a")() b = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "b")() out_a = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "out_a")() @@ -219,7 +220,7 @@ def prog_out_as_tuple( def test_program_out_as_tuple( cartesian_case, -): # TODO: this fails for most backends, merge PR #1893 first +): a = cases.allocate(cartesian_case, prog_out_as_tuple, "a")() b = cases.allocate(cartesian_case, prog_out_as_tuple, "b")() out = cases.allocate(cartesian_case, prog_out_as_tuple, "out")() @@ -237,6 +238,65 @@ def test_program_out_as_tuple( ) +@gtx.program +def prog_out_as_tuple_different_sizes( + a: IField, + b: JField, + out: tuple[JField, IField], + i_size: gtx.int32, + j_size: gtx.int32, + restrict_i_0: gtx.int32, + restrict_i_1: gtx.int32, + restrict_j_0: gtx.int32, + restrict_j_1: gtx.int32, +): + testee( + a, + b, + out=out, + domain=( + {JDim: (restrict_j_0, j_size + restrict_j_1)}, + {IDim: (restrict_i_0, i_size + restrict_i_1)}, + ), + ) + + +def test_program_out_as_tuple_different_sizes( + cartesian_case, +): + restrict_i = (1, -3) + restrict_j = (2, -4) + i_size = 
cartesian_case.default_sizes[IDim] + j_size = cartesian_case.default_sizes[JDim] + a = cases.allocate(cartesian_case, prog_out_as_tuple_different_sizes, "a")() + b = cases.allocate(cartesian_case, prog_out_as_tuple_different_sizes, "b")() + out = cases.allocate( + cartesian_case, + prog_out_as_tuple_different_sizes, + "out", + extend={IDim: (-restrict_i[0], restrict_i[1]), JDim: (-restrict_j[0], restrict_j[1])}, + )() + + cases.verify( + cartesian_case, + prog_out_as_tuple_different_sizes, + a, + b, + out, + i_size, + j_size, + restrict_i[0], + restrict_i[1], + restrict_j[0], + restrict_j[1], + inout=(out), + ref=( + b.ndarray[restrict_j[0] : j_size + restrict_j[1]], + a.ndarray[restrict_i[0] : i_size + restrict_i[1]], + ), + ) + + @gtx.field_operator def testee_nested_tuples( a: IField, @@ -454,16 +514,39 @@ def prog_temporary( out_cell: CField, c_size: gtx.int32, e_size: gtx.int32, + restrict_edge_0: gtx.int32, + restrict_edge_1: gtx.int32, + restrict_cell_0: gtx.int32, + restrict_cell_1: gtx.int32, ): testee_temporary( - a, out=(out_edge, out_cell), domain=({Edge: (0, e_size)}, {Cell: (0, c_size)}) - ) # TODO: specify other domain sizes? 
+ a, + out=(out_edge, out_cell), + domain=( + {Edge: (restrict_edge_0, e_size + restrict_edge_1)}, + {Cell: (restrict_cell_0, c_size + restrict_cell_1)}, + ), + ) def test_program_temporary(unstructured_case): + restrict_edge = (4, -2) + restrict_cell = (3, -1) + cell_size = unstructured_case.default_sizes[Cell] + edge_size = unstructured_case.default_sizes[Edge] a = cases.allocate(unstructured_case, prog_temporary, "a")() - out_edge = cases.allocate(unstructured_case, prog_temporary, "out_edge")() - out_cell = cases.allocate(unstructured_case, prog_temporary, "out_cell")() + out_edge = cases.allocate( + unstructured_case, + prog_temporary, + "out_edge", + extend={Edge: (-restrict_edge[0], restrict_edge[1])}, + )() + out_cell = cases.allocate( + unstructured_case, + prog_temporary, + "out_cell", + extend={Cell: (-restrict_cell[0], restrict_cell[1])}, + )() e2v = (a.ndarray)[unstructured_case.offset_provider["E2V"].asnumpy()[:, 1]] cases.verify( @@ -472,10 +555,19 @@ def test_program_temporary(unstructured_case): a, out_edge, out_cell, - unstructured_case.default_sizes[Cell], - unstructured_case.default_sizes[Edge], + cell_size, + edge_size, + restrict_edge[0], + restrict_edge[1], + restrict_cell[0], + restrict_cell[1], inout=(out_edge, out_cell), - ref=(e2v, e2v[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]]), + ref=( + e2v[restrict_edge[0] : edge_size + restrict_edge[1]], + e2v[unstructured_case.offset_provider["C2E"].asnumpy()[:, 1]][ + restrict_cell[0] : cell_size + restrict_cell[1] + ], + ), ) @@ -497,7 +589,6 @@ def test_direct_fo_orig(cartesian_case): # TODO: # - vertical staggering with dependency -# - test with different sizes (explicit domain and no domain) (use extend) # - cleanup and refactor tests # From 002b4c87bbcf5f8e7f5733815f396f06dacc262e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 9 Sep 2025 16:12:10 +0200 Subject: [PATCH 19/44] Clean up --- src/gt4py/next/ffront/past_to_itir.py | 20 +++++--------------- 1 file changed, 
5 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 8904d3bacb..32fdffe77b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,7 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils +from gt4py.eve import NodeTranslator, concepts, traits from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -314,14 +314,12 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, - out_field: past.Name | past.Subscript | past.Dict, + out_field: past.Name, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: assert isinstance(out_field.type, ts.TypeSpec) - out_field_types = type_info.primitive_constituents( - out_field.type if hasattr(out_field, "type") else out_field - ).to_list() + out_field_types = type_info.primitive_constituents(out_field.type).to_list() out_dims = cast(ts.FieldType, out_field_types[0]).dims if any( @@ -329,24 +327,16 @@ def _construct_itir_domain_arg( for out_field_type in out_field_types ): raise AssertionError( - f"Expected constituents of '{getattr(out_field, 'id', out_field)}' argument to be" + f"Expected constituents of '{out_field.id}' argument to be" " fields defined on the same dimensions. This error should be " " caught in type deduction already." 
) - # if the out_field is a (potentially nested) tuple we get the domain from its first - # element - first_out_el_path = eve_utils.first( - type_info.primitive_constituents(out_field.type, with_path_arg=True) - )[1] - first_out_el = functools.reduce( - lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id - ) domain_args = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension dim_range = im.call("get_domain_range")( - first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) + out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds From 1e865190316b322b0fd7093501d89c3bd102f553 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 23 Sep 2025 15:16:50 +0200 Subject: [PATCH 20/44] Extend and refactor to fix tests --- src/gt4py/next/ffront/past_to_itir.py | 125 +++++++++++++------------- 1 file changed, 63 insertions(+), 62 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 32fdffe77b..c7bea8d12d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,7 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, concepts, traits +from gt4py.eve import NodeTranslator, traits from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -331,42 +331,53 @@ def _construct_itir_domain_arg( " fields defined on the same dimensions. This error should be " " caught in type deduction already." 
) + primitive_paths = [ + path for _, path in type_info.primitive_constituents(out_field.type, with_path_arg=True) + ] + tuple_elements = [ + functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, out_field.id) + for path in primitive_paths + ] domain_args = [] - for dim_i, dim in enumerate(out_dims): - # an expression for the range of a dimension - dim_range = im.call("get_domain_range")( - out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) - ) - dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) - # bounds - lower: itir.Expr - upper: itir.Expr - if node_domain is not None: - assert isinstance(node_domain, past.Dict) - lower, upper = self._construct_itir_initialized_domain_arg(dim_i, dim, node_domain) - else: - lower = self._visit_slice_bound( - slices[dim_i].lower if slices else None, - dim_start, - dim_start, - dim_stop, - ) - upper = self._visit_slice_bound( - slices[dim_i].upper if slices else None, - dim_stop, - dim_start, - dim_stop, + for el in tuple_elements: + for dim_i, dim in enumerate(out_dims): + # an expression for the range of a dimension + dim_range = im.call("get_domain_range")( + el, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) - if dim.kind == common.DimensionKind.LOCAL: - raise ValueError(f"common.Dimension '{dim.value}' must not be local.") - domain_args.append( - itir.FunCall( - fun=itir.SymRef(id="named_range"), - args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], + dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) + # bounds + lower: itir.Expr + upper: itir.Expr + if node_domain is not None: + assert isinstance(node_domain, past.Dict) + lower, upper = self._construct_itir_initialized_domain_arg( + dim_i, dim, node_domain + ) + else: + lower = self._visit_slice_bound( + slices[dim_i].lower if slices else None, + dim_start, + dim_start, + dim_stop, + ) + upper = self._visit_slice_bound( + slices[dim_i].upper if slices else None, + dim_stop, + 
dim_start, + dim_stop, + ) + + if dim.kind == common.DimensionKind.LOCAL: + raise ValueError(f"common.Dimension '{dim.value}' must not be local.") + domain_args.append( + itir.FunCall( + fun=itir.SymRef(id="named_range"), + args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], + ) ) - ) if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" @@ -439,38 +450,28 @@ def _visit_stencil_call_out_arg( elif isinstance(out_arg, past.TupleExpr) or ( isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) ): - flattened = _flatten_tuple_expr(out_arg) - - first_field = flattened[0] - field_slice = None - if isinstance(first_field, past.Subscript): - raise AssertionError # TODO support slicing of multiple fields with different domain - assert all(isinstance(field, past.Subscript) for field in flattened), ( - "Incompatible field in tuple: either all fields or no field must be sliced." - ) - assert all( - concepts.eq_nonlocated( - first_field.slice_, - field.slice_, - ) - for field in flattened - ), "Incompatible field in tuple: all fields must be sliced in the same way." 
- field_slice = self._compute_field_slice(first_field) - first_field = first_field.value - domain_expr = type_info.apply_to_primitive_constituents( lambda field_type, path: self._construct_itir_domain_arg( - functools.reduce( - lambda e, i: ( - e.elts[i] - if isinstance(e, past.TupleExpr) - else past.Name(type=e.type.types[i], id=e.id, location=e.location) - if isinstance(e, past.Name) and isinstance(e.type, ts.TupleType) - else e - ), - path, - out_arg, - ), + functools.reduce(lambda e, i: e.elts[i], path, out_arg), + functools.reduce(lambda e, i: e.elts[i], path, domain_arg) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg, + None, # TODO: support slicing + ) + if isinstance(out_arg, past.TupleExpr) + else self._construct_itir_domain_arg( + # Create a temporary past.Name-like object that carries the indexed information + type( + "SyntheticElement", + (), + { + "id": functools.reduce( + lambda expr, i: im.tuple_get(i, expr), path, out_arg.id + ), + "type": field_type, + "location": out_arg.location, + }, + )(), ( functools.reduce(lambda e, i: e.elts[i], path, domain_arg) if isinstance(domain_arg, past.TupleExpr) From ef9a29e2d559e38c6fb6c6728908da6d383714e8 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 23 Sep 2025 18:44:36 +0200 Subject: [PATCH 21/44] Fix several tests --- src/gt4py/next/ffront/decorator.py | 1 + src/gt4py/next/ffront/past_to_itir.py | 11 +++-- .../next/iterator/transforms/infer_domain.py | 2 +- tests/next_tests/integration_tests/cases.py | 2 +- .../ffront_tests/test_execution.py | 45 ++++++++++++------- .../ffront_tests/test_import_from_mod.py | 20 +++++---- .../ffront_tests/test_math_unary_builtins.py | 4 +- .../feature_tests/ffront_tests/test_where.py | 16 +++++-- .../feature_tests/test_util_cases.py | 4 +- .../ffront_tests/test_func_to_past.py | 2 +- 10 files changed, 70 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8cde17ffc7..48322abc64 
100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -636,6 +636,7 @@ def program_inner(definition: types.FunctionType) -> Program: # else: # return outs[common.domain(domains)] + @dataclasses.dataclass(frozen=True) class FieldOperator(GTCallable, Generic[OperatorNodeT]): """ diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c7bea8d12d..4931cf186a 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -314,7 +314,7 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, - out_field: past.Name, + out_field: past.Name | past.Subscript, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: @@ -334,8 +334,13 @@ def _construct_itir_domain_arg( primitive_paths = [ path for _, path in type_info.primitive_constituents(out_field.type, with_path_arg=True) ] + assert isinstance(out_field, (past.Name, past.Subscript)) or (hasattr(out_field, "id")) tuple_elements = [ - functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, out_field.id) + functools.reduce( + lambda expr, i: im.tuple_get(i, expr), + path, + out_field.value.id if isinstance(out_field, past.Subscript) else out_field.id, + ) for path in primitive_paths ] @@ -462,7 +467,7 @@ def _visit_stencil_call_out_arg( else self._construct_itir_domain_arg( # Create a temporary past.Name-like object that carries the indexed information type( - "SyntheticElement", + "NameLikeObject", (), { "id": functools.reduce( diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 7fa5151e3a..8c3cc2f3d0 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -483,7 +483,7 @@ def infer_expr( domain, fill_value=DomainAccessDescriptor.NEVER, # el_types already has the right structure, we 
only want to change domain - bidirectional=False, + bidirectional=False if not isinstance(expr.type, ts.DeferredType) else True, ) if cpm.is_applied_as_fieldop(expr) and cpm.is_call_to(expr.fun.args[0], "scan"): diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 19e8e3ea6d..afd93fe427 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -567,7 +567,7 @@ def unstructured_case( def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, - default_sizes={**unstructured_case.default_sizes, KDim: 10}, + default_sizes={**unstructured_case.default_sizes}, offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f87e8bf5cf..288fe38c41 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -370,8 +370,9 @@ def testee(qc: cases.IKFloatField, scalar: float): qc = cases.allocate(cartesian_case, testee, "qc").zeros()() scalar = 1.0 + isize = cartesian_case.default_sizes[IDim] ksize = cartesian_case.default_sizes[KDim] - expected = np.full((ksize, ksize), np.arange(start=1, stop=11, step=1).astype(float64)) + expected = np.full((isize, ksize), np.arange(start=1, stop=ksize + 1, step=1).astype(float64)) cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) @@ -394,8 +395,9 @@ def testee_op( qc = cases.allocate(cartesian_case, testee_op, "qc").zeros()() tuple_scalar = (1.0, (1.0, 0.0)) + isize = cartesian_case.default_sizes[IDim] ksize = cartesian_case.default_sizes[KDim] - expected = np.full((ksize, ksize), np.arange(start=1.0, stop=11.0), dtype=float) + expected = np.full((isize, ksize), 
np.arange(start=1.0, stop=ksize + 1), dtype=float) cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) @@ -1052,22 +1054,22 @@ def fieldop_domain(a: cases.IField) -> cases.IField: return a + a @gtx.program - def program_domain(a: cases.IField, out: cases.IField): - fieldop_domain(a, out=out, domain={IDim: (minimum(1, 2), 9)}) + def program_domain(a: cases.IField, size: int32, out: cases.IField): + fieldop_domain(a, out=out, domain={IDim: (minimum(1, 2), size)}) a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - + size = cartesian_case.default_sizes[IDim] ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain - ref[1:9] = a.asnumpy()[1:9] * 2 + ref[1:size] = a.asnumpy()[1:size] * 2 - cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) + cases.verify(cartesian_case, program_domain, a, size, out, inout=out, ref=ref) @pytest.mark.uses_floordiv def test_domain_input_bounds(cartesian_case): lower_i = 1 - upper_i = 10 + upper_i = cartesian_case.default_sizes[IDim] + 1 @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: @@ -1090,9 +1092,9 @@ def program_domain( def test_domain_input_bounds_1(cartesian_case): lower_i = 1 - upper_i = 9 - lower_j = 4 - upper_j = 6 + upper_i = cartesian_case.default_sizes[IDim] + lower_j = cartesian_case.default_sizes[JDim] - 3 + upper_j = cartesian_case.default_sizes[JDim] - 1 @gtx.field_operator def fieldop_domain(a: cases.IJField) -> cases.IJField: @@ -1142,19 +1144,30 @@ def fieldop_domain_tuple( @gtx.program def program_domain_tuple( - inp0: cases.IJField, inp1: cases.IJField, out0: cases.IJField, out1: cases.IJField + inp0: cases.IJField, + inp1: cases.IJField, + out0: cases.IJField, + out1: cases.IJField, + isize: int32, + jsize: int32, ): - fieldop_domain_tuple(inp0, inp1, out=(out0, out1), domain={IDim: (1, 9), JDim: (4, 6)}) + fieldop_domain_tuple( + inp0, inp1, 
out=(out0, out1), domain={IDim: (1, isize), JDim: (jsize - 2, jsize)} + ) inp0 = cases.allocate(cartesian_case, program_domain_tuple, "inp0")() inp1 = cases.allocate(cartesian_case, program_domain_tuple, "inp1")() out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() + isize = cartesian_case.default_sizes[IDim] + jsize = cartesian_case.default_sizes[JDim] - 1 ref0 = out0.asnumpy().copy() - ref0[1:9, 4:6] = inp0.asnumpy()[1:9, 4:6] + inp1.asnumpy()[1:9, 4:6] + ref0[1:isize, jsize - 2 : jsize] = ( + inp0.asnumpy()[1:isize, jsize - 2 : jsize] + inp1.asnumpy()[1:isize, jsize - 2 : jsize] + ) ref1 = out1.asnumpy().copy() - ref1[1:9, 4:6] = inp1.asnumpy()[1:9, 4:6] + ref1[1:isize, jsize - 2 : jsize] = inp1.asnumpy()[1:isize, jsize - 2 : jsize] cases.verify( cartesian_case, @@ -1163,6 +1176,8 @@ def program_domain_tuple( inp1, out0, out1, + isize, + jsize, inout=(out0, out1), ref=(ref0, ref1), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py index 8438a735dc..952dcb31bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_import_from_mod.py @@ -10,11 +10,11 @@ import numpy as np import gt4py.next as gtx -from gt4py.next import broadcast, astype +from gt4py.next import broadcast, astype, int32 from next_tests import integration_tests from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.cases import cartesian_case, IDim, KDim from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -29,27 +29,29 @@ def mod_op(f: cases.IField) -> cases.IKField: return f_i_k @gtx.program - def mod_prog(f: 
cases.IField, out: cases.IKField): + def mod_prog(f: cases.IField, isize: int32, ksize: int32, out: cases.IKField): mod_op( f, out=out, domain={ integration_tests.cases.IDim: ( 0, - 8, + isize, ), # Nested import done on purpose, do not change - cases.KDim: (0, 3), + cases.KDim: (0, ksize), }, ) f = cases.allocate(cartesian_case, mod_prog, "f")() out = cases.allocate(cartesian_case, mod_prog, "out")() expected = np.zeros_like(out.asnumpy()) - expected[0:8, 0:3] = np.reshape(np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape)[ - 0:8, 0:3 - ] + isize = cartesian_case.default_sizes[IDim] - 1 + ksize = cartesian_case.default_sizes[KDim] - 2 + expected[0:isize, 0:ksize] = np.reshape( + np.repeat(f.asnumpy(), out.shape[1], axis=0), out.shape + )[0:isize, 0:ksize] - cases.verify(cartesian_case, mod_prog, f, out=out, ref=expected) + cases.verify(cartesian_case, mod_prog, f, isize, ksize, out=out, ref=expected) # TODO: these set of features should be allowed as module imports in a later PR diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 1707adada8..bf6dd34cca 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -75,7 +75,9 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = cartesian_case.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) + inp1 = cartesian_case.as_field( + [IDim], np.asarray(range(cartesian_case.default_sizes[IDim]), dtype=int32) - 5 + ) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py index 7d634cec90..cd10c10437 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -11,6 +11,7 @@ import pytest from next_tests.integration_tests.cases import IDim, JDim, KDim, Koff, cartesian_case from gt4py import next as gtx +from gt4py.next import int32 from gt4py.next.ffront.fbuiltins import where, broadcast from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -27,8 +28,14 @@ def fieldop_where_k_offset( return where(k_index > 0, inp(Koff[-1]), 2) @gtx.program - def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): - fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + def prog( + inp: cases.IKField, + k_index: gtx.Field[[KDim], gtx.IndexType], + isize: int32, + ksize: int32, + out: cases.IKField, + ): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, isize), KDim: (1, ksize)}) inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() k_index = cases.allocate( @@ -37,8 +44,9 @@ def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cas out = cases.allocate(cartesian_case, fieldop_where_k_offset, cases.RETURN)() ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) - - cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) + isize = cartesian_case.default_sizes[IDim] + ksize = cartesian_case.default_sizes[KDim] + cases.verify(cartesian_case, prog, inp, k_index, isize, ksize, out=out, ref=ref) def test_same_size_fields(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 5727f29a2a..e5c5182b81 100644 
--- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -37,7 +37,7 @@ def test_allocate_default_unique(cartesian_case): a = cases.allocate(cartesian_case, mixed_args, "a")() assert np.min(a.asnumpy()) == 1 - assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) + assert np.max(a.asnumpy()) == np.prod(tuple(list(cartesian_case.default_sizes.values())[:3])) b = cases.allocate(cartesian_case, mixed_args, "b")() @@ -46,7 +46,7 @@ def test_allocate_default_unique(cartesian_case): c = cases.allocate(cartesian_case, mixed_args, "c")() assert np.min(c.asnumpy()) == b + 1 - assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + 1 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())[:3]) * 2 + 1 def test_allocate_return_default_zeros(cartesian_case): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 4185ef0bd7..2234895b5b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -154,7 +154,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): assert exc_info.match("Invalid call to 'domain_format_1'") assert ( - re.search("TupleExpr are only allowed in 'domain', if", exc_info.value.__cause__.args[0]) + re.search("Tuple domain requires tuple output", exc_info.value.__cause__.args[0]) is not None ) From 5c1edaeb45c7bbafc37b7ca0968938419f783f67 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 23 Sep 2025 20:14:07 +0200 Subject: [PATCH 22/44] Enable multiple output domains in direct fo calls and fix some tests --- src/gt4py/next/ffront/decorator.py | 26 +++++---- tests/next_tests/integration_tests/cases.py | 1 + .../feature_tests/dace/test_program.py | 17 +----- 
.../ffront_tests/ffront_test_utils.py | 4 ++ .../ffront_tests/test_execution.py | 2 +- .../test_multiple_output_domains.py | 54 ++++++++++++------- 6 files changed, 53 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 48322abc64..25c4b5631a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -33,7 +33,6 @@ embedded as next_embedded, errors, metrics, - utils, ) from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -625,16 +624,17 @@ def program_inner(definition: types.FunctionType) -> Program: OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) -# def _slice_outs( -# outs: common.Field | tuple[common.Field | tuple, ...], -# domains: common.Domain | tuple[common.Domain | tuple, ...], -# ) -> common.Field | tuple[common.Field | tuple, ...]: -# if isinstance(outs, tuple): -# if not isinstance(domains, tuple): -# domains = tuple([domains] * len(outs)) -# return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) -# else: -# return outs[common.domain(domains)] + +def _slice_outs( + outs: common.Field | tuple[common.Field | tuple, ...], + domains: common.Domain | tuple[common.Domain | tuple, ...], +) -> common.Field | tuple[common.Field | tuple, ...]: + if isinstance(outs, tuple): + if not isinstance(domains, tuple): + domains = tuple([domains] * len(outs)) + return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) + else: + return outs[common.domain(domains)] @dataclasses.dataclass(frozen=True) @@ -778,9 +778,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - # out = _slice_outs(out, kwargs.pop("domain")) - domain = common.domain(kwargs.pop("domain")) - out = utils.tree_map(lambda f: f[domain])(out) + out = _slice_outs(out, 
kwargs.pop("domain")) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index afd93fe427..26ab84931a 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -722,6 +722,7 @@ def from_mesh_descriptor( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, + KDim: mesh_descriptor.num_levels, }, grid_type=common.GridType.UNSTRUCTURED, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 2ab97814b9..5ec5f6f335 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -42,21 +42,6 @@ def exec_alloc_descriptor(request): yield request.param -@pytest.fixture -def cartesian(request, gtir_dace_backend): - yield cases.Case( - backend=gtir_dace_backend, - offset_provider={ - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - }, - default_sizes={IDim: 10, JDim: 10, KDim: 10}, - grid_type=common.GridType.CARTESIAN, - allocator=gtir_dace_backend.allocator, - ) - - @pytest.fixture def unstructured(request, exec_alloc_descriptor, mesh_descriptor): # noqa: F811 yield cases.Case( @@ -66,7 +51,7 @@ def unstructured(request, exec_alloc_descriptor, mesh_descriptor): # noqa: F811 Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, + KDim: mesh_descriptor.num_levels, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 
cd6aaf5de3..39d68227ad 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -224,6 +224,7 @@ def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh(allocator) -> MeshDescriptor: num_vertices = 9 num_cells = 8 + num_levels = 10 v2e_arr = np.array( [ @@ -313,6 +314,7 @@ def simple_mesh(allocator) -> MeshDescriptor: num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, + num_levels=num_levels, offset_provider=offset_provider, offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -324,6 +326,7 @@ def skip_value_mesh(allocator) -> MeshDescriptor: num_vertices = 7 num_cells = 6 num_edges = 12 + num_levels = 10 v2e_arr = np.array( [ @@ -408,6 +411,7 @@ def skip_value_mesh(allocator) -> MeshDescriptor: num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, + num_levels=num_levels, offset_provider=offset_provider, offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 288fe38c41..159cf2d00a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -756,7 +756,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I @pytest.mark.parametrize("forward", [True, False]) def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 - expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) + expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[KDim], 1) out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 138d520b6b..bf32c97209 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -587,25 +587,39 @@ def test_direct_fo_orig(cartesian_case): ) +def test_direct_fo(cartesian_case): + a = cases.allocate(cartesian_case, testee, "a")() + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify( + cartesian_case, + testee, + a, + b, + out=out, + ref=(b, a), + domain=( + {JDim: (0, cartesian_case.default_sizes[JDim])}, + {IDim: (0, cartesian_case.default_sizes[IDim])}, + ), + ) + + +def test_direct_fo_no_domain(cartesian_case): + a = cases.allocate(cartesian_case, testee, "a")() + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify( + cartesian_case, + testee, + a, + b, + out=out, + ref=(b, a), + ) + + # TODO: # - vertical staggering with dependency -# - cleanup and refactor tests - -# -# def test_direct_fo(cartesian_case): -# a = cases.allocate(cartesian_case, testee, "a")() -# b = cases.allocate(cartesian_case, testee, "b")() -# out = cases.allocate(cartesian_case, testee, cases.RETURN)() -# -# cases.verify( -# cartesian_case, -# testee, -# a, -# b, -# out=out, -# ref=(b, a), -# domain=( -# {JDim: (0, cartesian_case.default_sizes[JDim])}, -# {IDim: (0, cartesian_case.default_sizes[IDim])}, -# ), -# ) From e357b89af61190f250dec659485e665b0fcd1b07 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 24 Sep 2025 14:59:53 +0200 Subject: [PATCH 23/44] Refactor and make slices work --- src/gt4py/next/embedded/operators.py | 3 +- 
.../next/ffront/past_passes/type_deduction.py | 19 ++++----- src/gt4py/next/ffront/past_to_itir.py | 35 ++++++++++------ .../test_multiple_output_domains.py | 42 ++++++++++++++++++- 4 files changed, 72 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index d1b308fc7c..df3f05a209 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -110,7 +110,6 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> out_domain = domain if domain is not None else _get_out_domain(out) - # TODO? new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) with embedded_context.update(**new_context_kwargs): @@ -158,7 +157,7 @@ def impl(target: common.MutableField, source: common.Field, domain: common.Domai if not isinstance( domain, tuple - ): # TODO: use a generic condition that also works for nested domains and targets + ): domain = utils.tree_map(lambda _: domain)(target) impl(target, source, domain) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 3f771f281a..6d148c82ae 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -56,12 +56,11 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) - def check( - dom: past.Dict | past.TupleExpr, out: past.TupleExpr | past.Name, level: int = 0 - ) -> None: + def validate_domain_out(dom: past.Dict | past.TupleExpr, out: past.TupleExpr | past.Name, + is_nested: bool = False) -> None: if isinstance(dom, past.Dict): - # Only reject tuple outputs if nested (level > 0) - if level > 0 and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): + # Only reject tuple outputs if nested + if is_nested and (isinstance(out, past.TupleExpr) or 
isinstance(out, ts.TupleType)): raise ValueError("Domain dict cannot map to tuple outputs.") if len(dom.values_) == 0 and len(dom.keys_) == 0: @@ -72,17 +71,17 @@ def check( raise ValueError( f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." ) + for domain_values in dom.values_: if len(domain_values.elts) != 2: raise ValueError( f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): + if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar(domain_values.elts[1]): raise ValueError( f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." ) + elif isinstance(dom, past.TupleExpr): if isinstance(out, past.TupleExpr): out_elts = out.elts @@ -95,12 +94,12 @@ def check( raise ValueError("Mismatched tuple lengths between domain and output.") for d, o in zip(dom.elts, out_elts): - check(d, o, level=level + 1) + validate_domain_out(d, o, is_nested=True) else: raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.") - check(new_kwargs["domain"], new_kwargs["out"]) + validate_domain_out(new_kwargs["domain"], new_kwargs["out"]) class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator): diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4931cf186a..7cf7570c13 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -455,14 +455,25 @@ def _visit_stencil_call_out_arg( elif isinstance(out_arg, past.TupleExpr) or ( isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) ): + def get_field_and_slice(field_expr, path): + """Extract field and its slice for a given path through the tuple structure.""" + current_field = functools.reduce(lambda e, i: e.elts[i], path, out_arg) + + if 
isinstance(current_field, past.Subscript): + return current_field.value, self._compute_field_slice(current_field) + else: + return current_field, None + domain_expr = type_info.apply_to_primitive_constituents( - lambda field_type, path: self._construct_itir_domain_arg( - functools.reduce(lambda e, i: e.elts[i], path, out_arg), - functools.reduce(lambda e, i: e.elts[i], path, domain_arg) - if isinstance(domain_arg, past.TupleExpr) - else domain_arg, - None, # TODO: support slicing - ) + lambda field_type, path: ( + lambda field, slice_info: self._construct_itir_domain_arg( + field, + functools.reduce(lambda e, i: e.elts[i], path, domain_arg) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg, + slice_info, + ) + )(*get_field_and_slice(None, path)) if isinstance(out_arg, past.TupleExpr) else self._construct_itir_domain_arg( # Create a temporary past.Name-like object that carries the indexed information @@ -477,12 +488,10 @@ def _visit_stencil_call_out_arg( "location": out_arg.location, }, )(), - ( - functools.reduce(lambda e, i: e.elts[i], path, domain_arg) - if isinstance(domain_arg, past.TupleExpr) - else domain_arg - ), - None, # TODO: support slicing + functools.reduce(lambda e, i: e.elts[i], path, domain_arg) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg, + None, # Name with TupleType doesn't support per-field slicing ), out_arg.type, with_path_arg=True, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index bf32c97209..85ed683558 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -7,6 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import copy + +import numpy as np import pytest import gt4py.next as gtx 
@@ -43,14 +45,14 @@ @gtx.field_operator -def testee_no_tuple(a: IField, b: IField) -> IField: +def testee_no_tuple(a: IField, b: JField) -> IField: return a @gtx.program def prog_no_tuple( a: IField, - b: IField, + b: JField, out_a: IField, out_b: IField, i_size: gtx.int32, @@ -206,6 +208,42 @@ def test_program(cartesian_case): ref=(b, a), ) +@gtx.program +def prog_slicing( + a: IField, + b: JField, + out_a: IField, + out_b: JField, + i_size: gtx.int32, + j_size: gtx.int32, +): + testee( + a, + b, + out=(out_b[2:-2], out_a[1:-1]), + ) + + +def test_program_slicing(cartesian_case): + a = cases.allocate(cartesian_case, prog, "a")() + b = cases.allocate(cartesian_case, prog, "b")() + out_a = cases.allocate(cartesian_case, prog, "out_a")() + out_b = cases.allocate(cartesian_case, prog, "out_b")() + out_a_ =copy.deepcopy(out_a) + out_b_ =copy.deepcopy(out_b) + cases.verify( + cartesian_case, + prog_slicing, + a, + b, + out_a, + out_b, + cartesian_case.default_sizes[IDim], + cartesian_case.default_sizes[JDim], + inout=(out_b, out_a), + ref=(np.concatenate([out_b_.ndarray[0:2], b.ndarray[2:-2], out_b_.ndarray[-2:]]), np.concatenate([out_a_.ndarray[0:1], a.ndarray[1:-1], out_a_.ndarray[-1:]])), + ) + @gtx.program def prog_out_as_tuple( From 3a50b6045a3fe2a4552f22facaf57df356d29050 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 30 Sep 2025 18:39:51 +0200 Subject: [PATCH 24/44] Remove num_levels from unstructured meshes and reformat --- src/gt4py/next/embedded/operators.py | 4 +--- src/gt4py/next/ffront/past_passes/type_deduction.py | 11 ++++++++--- src/gt4py/next/ffront/past_to_itir.py | 3 ++- tests/next_tests/integration_tests/cases.py | 3 +-- .../feature_tests/ffront_tests/ffront_test_utils.py | 4 ---- .../ffront_tests/test_multiple_output_domains.py | 11 +++++++---- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index df3f05a209..0672a04c0a 100644 --- 
a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -155,9 +155,7 @@ def impl(target: common.MutableField, source: common.Field, domain: common.Domai assert core_defs.is_scalar_type(source) target[domain] = source - if not isinstance( - domain, tuple - ): + if not isinstance(domain, tuple): domain = utils.tree_map(lambda _: domain)(target) impl(target, source, domain) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 6d148c82ae..54054df3a2 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -56,8 +56,11 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) - def validate_domain_out(dom: past.Dict | past.TupleExpr, out: past.TupleExpr | past.Name, - is_nested: bool = False) -> None: + def validate_domain_out( + dom: past.Dict | past.TupleExpr, + out: past.TupleExpr | past.Name, + is_nested: bool = False, + ) -> None: if isinstance(dom, past.Dict): # Only reject tuple outputs if nested if is_nested and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): @@ -77,7 +80,9 @@ def validate_domain_out(dom: past.Dict | past.TupleExpr, out: past.TupleExpr | p raise ValueError( f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar(domain_values.elts[1]): + if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( + domain_values.elts[1] + ): raise ValueError( f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." 
) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index e0357eb415..719cd7a9e5 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,7 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, traits +from gt4py.eve import NodeTranslator, traits from gt4py.next import common, config, errors, utils as gtx_utils from gt4py.next.ffront import ( fbuiltins, @@ -463,6 +463,7 @@ def _visit_stencil_call_out_arg( elif isinstance(out_arg, past.TupleExpr) or ( isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) ): + def get_field_and_slice(field_expr, path): """Extract field and its slice for a given path through the tuple structure.""" current_field = functools.reduce(lambda e, i: e.elts[i], path, out_arg) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 26ab84931a..19e8e3ea6d 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -567,7 +567,7 @@ def unstructured_case( def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, - default_sizes={**unstructured_case.default_sizes}, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, ) @@ -722,7 +722,6 @@ def from_mesh_descriptor( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: mesh_descriptor.num_levels, }, grid_type=common.GridType.UNSTRUCTURED, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 39d68227ad..cd6aaf5de3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -224,7 +224,6 @@ def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh(allocator) -> MeshDescriptor: num_vertices = 9 num_cells = 8 - num_levels = 10 v2e_arr = np.array( [ @@ -314,7 +313,6 @@ def simple_mesh(allocator) -> MeshDescriptor: num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - num_levels=num_levels, offset_provider=offset_provider, offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -326,7 +324,6 @@ def skip_value_mesh(allocator) -> MeshDescriptor: num_vertices = 7 num_cells = 6 num_edges = 12 - num_levels = 10 v2e_arr = np.array( [ @@ -411,7 +408,6 @@ def skip_value_mesh(allocator) -> MeshDescriptor: num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - num_levels=num_levels, offset_provider=offset_provider, offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 85ed683558..1fb972c98a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -25,7 +25,6 @@ CField, VField, Cell, - Vertex, cartesian_case, unstructured_case, Case, @@ -208,6 +207,7 @@ def test_program(cartesian_case): ref=(b, a), ) + @gtx.program def prog_slicing( a: IField, @@ -229,8 +229,8 @@ def test_program_slicing(cartesian_case): b = cases.allocate(cartesian_case, prog, "b")() out_a = cases.allocate(cartesian_case, prog, "out_a")() out_b = cases.allocate(cartesian_case, prog, "out_b")() - out_a_ =copy.deepcopy(out_a) - out_b_ =copy.deepcopy(out_b) + out_a_ = copy.deepcopy(out_a) + out_b_ = 
copy.deepcopy(out_b) cases.verify( cartesian_case, prog_slicing, @@ -241,7 +241,10 @@ def test_program_slicing(cartesian_case): cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[JDim], inout=(out_b, out_a), - ref=(np.concatenate([out_b_.ndarray[0:2], b.ndarray[2:-2], out_b_.ndarray[-2:]]), np.concatenate([out_a_.ndarray[0:1], a.ndarray[1:-1], out_a_.ndarray[-1:]])), + ref=( + np.concatenate([out_b_.ndarray[0:2], b.ndarray[2:-2], out_b_.ndarray[-2:]]), + np.concatenate([out_a_.ndarray[0:1], a.ndarray[1:-1], out_a_.ndarray[-1:]]), + ), ) From 5b1bf8690ca8774fbb443838c53ce005f070709a Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 1 Oct 2025 10:16:40 +0200 Subject: [PATCH 25/44] Remove num_levels from MeshDescriptor --- .../integration_tests/feature_tests/dace/test_program.py | 1 - .../feature_tests/ffront_tests/ffront_test_utils.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py index 5ec5f6f335..0d6c44977e 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -51,7 +51,6 @@ def unstructured(request, exec_alloc_descriptor, mesh_descriptor): # noqa: F811 Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: mesh_descriptor.num_levels, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index cd6aaf5de3..7640553e6a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -211,9 +211,6 @@ def 
num_cells(self) -> int: ... @property def num_edges(self) -> int: ... - @property - def num_levels(self) -> int: ... - @property def offset_provider(self) -> common.OffsetProvider: ... From 0440779ad01ea2815369b9c8ff160fa33aecec4e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 1 Oct 2025 12:27:07 +0200 Subject: [PATCH 26/44] Try to refactor Domain vs DomainLike --- src/gt4py/next/common.py | 11 +++++++++++ src/gt4py/next/embedded/operators.py | 14 ++++++++------ src/gt4py/next/ffront/decorator.py | 9 +++++---- src/gt4py/next/iterator/embedded.py | 3 ++- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d3c4c854d5..cf3e73dfb7 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -29,6 +29,7 @@ Final, Generic, Literal, + MaybeNestedInTuple, NamedTuple, Never, Optional, @@ -592,6 +593,16 @@ def __getstate__(self) -> dict[str, Any]: ) # `Domain` is `Sequence[NamedRange]` and therefore a subset +def normalize_domains(domain_like: MaybeNestedInTuple[DomainLike]) -> MaybeNestedInTuple[Domain]: + """ + Convert a potentially nested tuple structure of `DomainLike` objects to `Domain` objects. + """ + if isinstance(domain_like, tuple): + return tuple(normalize_domains(item) for item in domain_like) + else: + return domain(domain_like) + + def domain(domain_like: DomainLike) -> Domain: """ Construct `Domain` from `DomainLike` object. 
diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 0672a04c0a..a57df1f204 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs +from gt4py.eve import extended_typing as xtyping from gt4py.next import common, errors, field_utils, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context from gt4py.next.field_utils import get_array_ns @@ -108,7 +109,9 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> domain = kwargs.pop("domain", None) - out_domain = domain if domain is not None else _get_out_domain(out) + out_domain = ( + common.normalize_domains(domain) if domain is not None else _get_out_domain(out) + ) new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) @@ -129,13 +132,13 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> def _get_vertical_range( - domain: common.Domain | tuple[common.Domain, ...], + domain: xtyping.MaybeNestedInTuple[common.Domain], ) -> common.NamedRange | eve.NothingType | tuple[common.NamedRange | eve.NothingType, ...]: if isinstance(domain, tuple): return tuple(_get_vertical_range(dom) for dom in domain) else: vertical_dim_filtered = [ - nr for nr in common.domain(domain) if nr.dim.kind == common.DimensionKind.VERTICAL + nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL ] assert len(vertical_dim_filtered) <= 1 return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING @@ -144,11 +147,10 @@ def _get_vertical_range( def _tuple_assign_field( target: tuple[common.MutableField | tuple, ...] | common.MutableField, source: tuple[common.Field | tuple, ...] 
| common.Field, - domain: common.DomainLike | tuple[common.DomainLike | tuple, ...], + domain: xtyping.MaybeNestedInTuple[common.Domain], ) -> None: @utils.tree_map - def impl(target: common.MutableField, source: common.Field, domain: common.DomainLike) -> None: - domain = common.domain(domain) + def impl(target: common.MutableField, source: common.Field, domain: common.Domain) -> None: if isinstance(source, common.Field): target[domain] = source[domain] else: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 64cd0ace3d..2a000c035e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -630,15 +630,15 @@ def program_inner(definition: types.FunctionType) -> Program: def _slice_outs( - outs: common.Field | tuple[common.Field | tuple, ...], - domains: common.Domain | tuple[common.Domain | tuple, ...], + outs: xtyping.MaybeNestedInTuple[common.Field], + domains: xtyping.MaybeNestedInTuple[common.Domain], ) -> common.Field | tuple[common.Field | tuple, ...]: if isinstance(outs, tuple): if not isinstance(domains, tuple): domains = tuple([domains] * len(outs)) return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) else: - return outs[common.domain(domains)] + return outs[domains] @dataclasses.dataclass(frozen=True) @@ -782,7 +782,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - out = _slice_outs(out, kwargs.pop("domain")) + dom = common.normalize_domains(kwargs.pop("domain")) + out = _slice_outs(out, dom) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 8f72ac3ff3..76cee86153 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1632,9 +1632,10 @@ def 
_validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider @runtime.set_at.register(EMBEDDED) def set_at( expr: common.Field, - domain: common.DomainLike | tuple[common.DomainLike | tuple, ...], + domain: xtyping.MaybeNestedInTuple[common.DomainLike], target: common.MutableField, ) -> None: + domain = common.normalize_domains(domain) operators._tuple_assign_field(target, expr, domain) From 4e72b822c18e0e13d28aa3bc54259d6381e7c04f Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 1 Oct 2025 14:34:42 +0200 Subject: [PATCH 27/44] Update tests and address TODO --- .../next/ffront/past_passes/type_deduction.py | 5 +++-- .../ffront_tests/ffront_test_utils.py | 20 +++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 54054df3a2..0cba6a421d 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -164,8 +164,9 @@ def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr def infer_type(elt: past.Dict | past.TupleExpr) -> ts.DomainType | ts.TupleType: if isinstance(elt, past.Dict): - # TODO: add check that Dict is DomainLike - return ts.DomainType(dims=[common.Dimension(elt.keys_[0].id)]) + assert all(isinstance(key, past.Name) for key in elt.keys_) + assert all(isinstance(key.type, ts.DimensionType) for key in elt.keys_) + return ts.DomainType(dims=[common.Dimension(key.id) for key in elt.keys_]) elif isinstance(elt, past.TupleExpr): return ts.TupleType(types=[infer_type(elt) for elt in elt.elts]) else: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 7640553e6a..a41eae7290 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -74,22 +74,22 @@ def __gt_allocator__( next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - pytest.param( - next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu - ), + # pytest.param( + # next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu + # ), # will use the default (embedded) execution, but input/output allocated with the provided allocator next_tests.definitions.EmbeddedIds.NUMPY_EXECUTION, - pytest.param( - next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu - ), + # pytest.param( + # next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu + # ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, ), - pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, - marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - ), + # pytest.param( + # next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), + # ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, From 9658d679596eaa24b4032ef7a62ed484a2f13710 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 1 Oct 2025 14:42:07 +0200 Subject: [PATCH 28/44] Revert unintentional change and update tests --- .../ffront_tests/ffront_test_utils.py | 20 +++++++++---------- .../test_multiple_output_domains.py | 13 ------------ 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a41eae7290..7640553e6a 100644 ---
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -74,22 +74,22 @@ def __gt_allocator__( next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - # pytest.param( - # next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu - # ), + pytest.param( + next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu + ), # will use the default (embedded) execution, but input/output allocated with the provided allocator next_tests.definitions.EmbeddedIds.NUMPY_EXECUTION, - # pytest.param( - # next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu - # ), + pytest.param( + next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU, marks=pytest.mark.requires_dace, ), - # pytest.param( - # next_tests.definitions.OptionalProgramBackendId.DACE_GPU, - # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - # ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), + ), pytest.param( next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 1fb972c98a..4b31d0f2c1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -37,8 +37,6 @@ mesh_descriptor, ) -from gt4py.next 
import common - KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) pytestmark = pytest.mark.uses_cartesian_shift @@ -53,7 +51,6 @@ def prog_no_tuple( a: IField, b: JField, out_a: IField, - out_b: IField, i_size: gtx.int32, ): testee_no_tuple(a, b, out=out_a, domain={IDim: (0, i_size)}) @@ -63,7 +60,6 @@ def test_program_no_tuple(cartesian_case): a = cases.allocate(cartesian_case, prog_no_tuple, "a")() b = cases.allocate(cartesian_case, prog_no_tuple, "b")() out_a = cases.allocate(cartesian_case, prog_no_tuple, "out_a")() - out_b = cases.allocate(cartesian_case, prog_no_tuple, "out_b")() cases.verify( cartesian_case, @@ -71,7 +67,6 @@ def test_program_no_tuple(cartesian_case): a, b, out_a, - out_b, cartesian_case.default_sizes[IDim], inout=out_a, ref=a, @@ -214,8 +209,6 @@ def prog_slicing( b: JField, out_a: IField, out_b: JField, - i_size: gtx.int32, - j_size: gtx.int32, ): testee( a, @@ -238,8 +231,6 @@ def test_program_slicing(cartesian_case): b, out_a, out_b, - cartesian_case.default_sizes[IDim], - cartesian_case.default_sizes[JDim], inout=(out_b, out_a), ref=( np.concatenate([out_b_.ndarray[0:2], b.ndarray[2:-2], out_b_.ndarray[-2:]]), @@ -660,7 +651,3 @@ def test_direct_fo_no_domain(cartesian_case): out=out, ref=(b, a), ) - - -# TODO: -# - vertical staggering with dependency From 21b8ca6f831db4c93e00be447c95ba2f991d998a Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 2 Oct 2025 18:10:45 +0200 Subject: [PATCH 29/44] Minor --- src/gt4py/next/ffront/past_to_itir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 719cd7a9e5..2d4da265d1 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -464,7 +464,7 @@ def _visit_stencil_call_out_arg( isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) ): - def get_field_and_slice(field_expr, path): + def get_field_and_slice(path: 
tuple[int, ...]): """Extract field and its slice for a given path through the tuple structure.""" current_field = functools.reduce(lambda e, i: e.elts[i], path, out_arg) @@ -482,7 +482,7 @@ def get_field_and_slice(field_expr, path): else domain_arg, slice_info, ) - )(*get_field_and_slice(None, path)) + )(*get_field_and_slice(path)) if isinstance(out_arg, past.TupleExpr) else self._construct_itir_domain_arg( # Create a temporary past.Name-like object that carries the indexed information From 98420e6a01bd25cd9f94bb3e2c19a22cd7f54a12 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 14:05:22 +0200 Subject: [PATCH 30/44] fix global tmps tuple splitting --- src/gt4py/next/iterator/transforms/global_tmps.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 092c3291bf..c6b2a118a6 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -9,7 +9,6 @@ from __future__ import annotations import functools -from collections.abc import Sequence from typing import Callable, Literal, Optional, cast from gt4py.eve import utils as eve_utils @@ -30,7 +29,7 @@ def select_elems_by_domain( select_domain: SymbolicDomain, target: itir.Expr, - args: Sequence[itir.Expr], + source: itir.Expr, domains: tuple[SymbolicDomain, ...], ): """ @@ -40,12 +39,12 @@ def select_elems_by_domain( """ new_targets = [] new_els = [] - for i, (el, el_domain) in enumerate(zip(args, domains)): + for i, el_domain in enumerate(domains): current_target = im.tuple_get(i, target) + current_source = im.tuple_get(i, source) if isinstance(el_domain, tuple): - assert cpm.is_call_to(el, "make_tuple") more_targets, more_els = select_elems_by_domain( - select_domain, current_target, el.args, el_domain + select_domain, current_target, current_source, el_domain ) new_els.extend(more_els) 
new_targets.extend(more_targets) @@ -53,16 +52,15 @@ def select_elems_by_domain( assert isinstance(el_domain, SymbolicDomain) if el_domain == select_domain: new_targets.append(current_target) - new_els.append(el) + new_els.append(current_source) return new_targets, new_els def _set_at_for_domain(stmt: itir.SetAt, domain: SymbolicDomain) -> itir.SetAt: """Extract all elements with given domain into a new `SetAt` statement.""" tuple_expr = stmt.expr - assert cpm.is_call_to(tuple_expr, "make_tuple") targets, expr_els = select_elems_by_domain( - domain, stmt.target, tuple_expr.args, stmt.expr.annex.domain + domain, stmt.target, tuple_expr, stmt.expr.annex.domain ) new_expr = im.make_tuple(*expr_els) new_expr.annex.domain = domain From 88401fea1b11a17d995d43ed33615993d439f876 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 14:28:10 +0200 Subject: [PATCH 31/44] nested direct fop call and cleanups --- .../test_multiple_output_domains.py | 63 ++++++++++++++----- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index 4b31d0f2c1..b0653f5b6f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -142,7 +142,7 @@ def testee(a: IField, b: JField) -> tuple[JField, IField]: @gtx.program -def prog_no_domain_differnet_fields( +def prog_no_domain_different_fields( a: IField, b: JField, out_a: IField, @@ -154,14 +154,14 @@ def prog_no_domain_differnet_fields( def test_program_no_domain_different_fields( cartesian_case, ): - a = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "a")() - b = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "b")() - out_a = 
cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "out_a")() - out_b = cases.allocate(cartesian_case, prog_no_domain_differnet_fields, "out_b")() + a = cases.allocate(cartesian_case, prog_no_domain_different_fields, "a")() + b = cases.allocate(cartesian_case, prog_no_domain_different_fields, "b")() + out_a = cases.allocate(cartesian_case, prog_no_domain_different_fields, "out_a")() + out_b = cases.allocate(cartesian_case, prog_no_domain_different_fields, "out_b")() cases.verify( cartesian_case, - prog_no_domain_differnet_fields, + prog_no_domain_different_fields, a, b, out_a, @@ -411,7 +411,8 @@ def prog_double_nested_tuples( c: KField, out_a: IField, out_b: JField, - out_c: KField, + out_c0: KField, + out_c1: KField, i_size: gtx.int32, j_size: gtx.int32, k_size: gtx.int32, @@ -420,7 +421,7 @@ def prog_double_nested_tuples( a, b, c, - out=((out_a, (out_b, out_c)), out_c), + out=((out_a, (out_b, out_c0)), out_c1), domain=( ({IDim: (0, i_size)}, ({JDim: (0, j_size)}, {KDim: (0, k_size)})), {KDim: (0, k_size)}, @@ -436,7 +437,8 @@ def test_program_double_nested_tuples( c = cases.allocate(cartesian_case, prog_double_nested_tuples, "c")() out_a = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_a")() out_b = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_b")() - out_c = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c")() + out_c0 = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c0")() + out_c1 = cases.allocate(cartesian_case, prog_double_nested_tuples, "out_c1")() cases.verify( cartesian_case, @@ -446,11 +448,12 @@ def test_program_double_nested_tuples( c, out_a, out_b, - out_c, + out_c0, + out_c1, cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[JDim], cartesian_case.default_sizes[KDim], - inout=((out_a, (out_b, out_c)), out_c), + inout=((out_a, (out_b, out_c0)), out_c1), ref=((a, (b, c)), c), ) @@ -619,6 +622,30 @@ def test_direct_fo_orig(cartesian_case): ) +def 
test_direct_fo_nested(cartesian_case): + a = cases.allocate(cartesian_case, testee_nested_tuples, "a")() + b = cases.allocate(cartesian_case, testee_nested_tuples, "b")() + c = cases.allocate(cartesian_case, testee_nested_tuples, "c")() + out = cases.allocate(cartesian_case, testee_nested_tuples, cases.RETURN)() + + cases.verify( + cartesian_case, + testee_nested_tuples, + a, + b, + c, + out=out, + ref=((a, b), c), + domain=( + ( + {IDim: (0, cartesian_case.default_sizes[IDim])}, + {JDim: (0, cartesian_case.default_sizes[JDim])}, + ), + {KDim: (0, cartesian_case.default_sizes[KDim])}, + ), + ) + + def test_direct_fo(cartesian_case): a = cases.allocate(cartesian_case, testee, "a")() b = cases.allocate(cartesian_case, testee, "b")() @@ -638,16 +665,18 @@ def test_direct_fo(cartesian_case): ) -def test_direct_fo_no_domain(cartesian_case): - a = cases.allocate(cartesian_case, testee, "a")() - b = cases.allocate(cartesian_case, testee, "b")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() +def test_direct_fo_nested_no_domain(cartesian_case): + a = cases.allocate(cartesian_case, testee_nested_tuples, "a")() + b = cases.allocate(cartesian_case, testee_nested_tuples, "b")() + c = cases.allocate(cartesian_case, testee_nested_tuples, "c")() + out = cases.allocate(cartesian_case, testee_nested_tuples, cases.RETURN)() cases.verify( cartesian_case, - testee, + testee_nested_tuples, a, b, + c, out=out, - ref=(b, a), + ref=((a, b), c), ) From a07294d760b5f269760b492f5ff218f450aad909 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 16:07:54 +0200 Subject: [PATCH 32/44] cleanup --- .../test_multiple_output_domains.py | 78 +++++++++---------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index b0653f5b6f..dd30caa726 100644 --- 
a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -74,7 +74,7 @@ def test_program_no_tuple(cartesian_case): @gtx.field_operator -def testee_orig(a: IField, b: IField) -> tuple[IField, IField]: +def fop_original(a: IField, b: IField) -> tuple[IField, IField]: return b, a @@ -86,7 +86,7 @@ def prog_orig( out_b: IField, i_size: gtx.int32, ): - testee_orig(a, b, out=(out_b, out_a), domain={IDim: (0, i_size)}) + fop_original(a, b, out=(out_b, out_a), domain={IDim: (0, i_size)}) def test_program_orig(cartesian_case): @@ -115,7 +115,7 @@ def prog_no_domain( out_a: IField, out_b: IField, ): - testee_orig(a, b, out=(out_b, out_a)) + fop_original(a, b, out=(out_b, out_a)) def test_program_no_domain(cartesian_case): @@ -137,7 +137,7 @@ def test_program_no_domain(cartesian_case): @gtx.field_operator -def testee(a: IField, b: JField) -> tuple[JField, IField]: +def fop_different_fields(a: IField, b: JField) -> tuple[JField, IField]: return b, a @@ -148,7 +148,7 @@ def prog_no_domain_different_fields( out_a: IField, out_b: JField, ): - testee(a, b, out=(out_b, out_a)) + fop_different_fields(a, b, out=(out_b, out_a)) def test_program_no_domain_different_fields( @@ -180,7 +180,9 @@ def prog( i_size: gtx.int32, j_size: gtx.int32, ): - testee(a, b, out=(out_b, out_a), domain=({JDim: (0, j_size)}, {IDim: (0, i_size)})) + fop_different_fields( + a, b, out=(out_b, out_a), domain=({JDim: (0, j_size)}, {IDim: (0, i_size)}) + ) def test_program(cartesian_case): @@ -210,7 +212,7 @@ def prog_slicing( out_a: IField, out_b: JField, ): - testee( + fop_different_fields( a, b, out=(out_b[2:-2], out_a[1:-1]), @@ -247,7 +249,7 @@ def prog_out_as_tuple( i_size: gtx.int32, j_size: gtx.int32, ): - testee(a, b, out=out, domain=({JDim: (0, j_size)}, {IDim: (0, i_size)})) + fop_different_fields(a, b, out=out, domain=({JDim: (0, j_size)}, {IDim: 
(0, i_size)})) def test_program_out_as_tuple( @@ -282,7 +284,7 @@ def prog_out_as_tuple_different_sizes( restrict_j_0: gtx.int32, restrict_j_1: gtx.int32, ): - testee( + fop_different_fields( a, b, out=out, @@ -330,7 +332,7 @@ def test_program_out_as_tuple_different_sizes( @gtx.field_operator -def testee_nested_tuples( +def fop_nested_tuples( a: IField, b: JField, c: KField, @@ -353,7 +355,7 @@ def prog_nested_tuples( j_size: gtx.int32, k_size: gtx.int32, ): - testee_nested_tuples( + fop_nested_tuples( a, b, c, @@ -390,7 +392,7 @@ def test_program_nested_tuples( @gtx.field_operator -def testee_double_nested_tuples( +def fop_double_nested_tuples( a: IField, b: JField, c: KField, @@ -417,7 +419,7 @@ def prog_double_nested_tuples( j_size: gtx.int32, k_size: gtx.int32, ): - testee_double_nested_tuples( + fop_double_nested_tuples( a, b, c, @@ -459,7 +461,7 @@ def test_program_double_nested_tuples( @gtx.field_operator -def testee_two_vertical_dims( +def fop_two_vertical_dims( a: KField, b: gtx.Field[[KHalfDim], gtx.float32] ) -> tuple[gtx.Field[[KHalfDim], gtx.float32], KField]: return b, a @@ -474,7 +476,7 @@ def prog_two_vertical_dims( k_size: gtx.int32, k_half_size: gtx.int32, ): - testee_two_vertical_dims( + fop_two_vertical_dims( a, b, out=(out_b, out_a), domain=({KHalfDim: (0, k_half_size)}, {KDim: (0, k_size)}) ) @@ -500,7 +502,7 @@ def test_program_two_vertical_dims(cartesian_case): @gtx.field_operator -def testee_shift_e2c(a: EField) -> tuple[CField, EField]: +def fop_shift_e2c(a: EField) -> tuple[CField, EField]: return a(C2E[1]), a @@ -512,9 +514,7 @@ def prog_unstructured( c_size: gtx.int32, e_size: gtx.int32, ): - testee_shift_e2c( - a, out=(out_a_shifted, out_a), domain=({Cell: (0, c_size)}, {Edge: (0, e_size)}) - ) + fop_shift_e2c(a, out=(out_a_shifted, out_a), domain=({Cell: (0, c_size)}, {Edge: (0, e_size)})) def test_program_unstructured(unstructured_case): @@ -536,7 +536,7 @@ def test_program_unstructured(unstructured_case): @gtx.field_operator -def 
testee_temporary(a: VField): +def fop_temporary(a: VField): edge = a(E2V[1]) cell = edge(C2E[1]) return edge, cell @@ -554,7 +554,7 @@ def prog_temporary( restrict_cell_0: gtx.int32, restrict_cell_1: gtx.int32, ): - testee_temporary( + fop_temporary( a, out=(out_edge, out_cell), domain=( @@ -607,13 +607,13 @@ def test_program_temporary(unstructured_case): def test_direct_fo_orig(cartesian_case): - a = cases.allocate(cartesian_case, testee_orig, "a")() - b = cases.allocate(cartesian_case, testee_orig, "b")() - out = cases.allocate(cartesian_case, testee_orig, cases.RETURN)() + a = cases.allocate(cartesian_case, fop_original, "a")() + b = cases.allocate(cartesian_case, fop_original, "b")() + out = cases.allocate(cartesian_case, fop_original, cases.RETURN)() cases.verify( cartesian_case, - testee_orig, + fop_original, a, b, out=out, @@ -623,14 +623,14 @@ def test_direct_fo_orig(cartesian_case): def test_direct_fo_nested(cartesian_case): - a = cases.allocate(cartesian_case, testee_nested_tuples, "a")() - b = cases.allocate(cartesian_case, testee_nested_tuples, "b")() - c = cases.allocate(cartesian_case, testee_nested_tuples, "c")() - out = cases.allocate(cartesian_case, testee_nested_tuples, cases.RETURN)() + a = cases.allocate(cartesian_case, fop_nested_tuples, "a")() + b = cases.allocate(cartesian_case, fop_nested_tuples, "b")() + c = cases.allocate(cartesian_case, fop_nested_tuples, "c")() + out = cases.allocate(cartesian_case, fop_nested_tuples, cases.RETURN)() cases.verify( cartesian_case, - testee_nested_tuples, + fop_nested_tuples, a, b, c, @@ -647,13 +647,13 @@ def test_direct_fo_nested(cartesian_case): def test_direct_fo(cartesian_case): - a = cases.allocate(cartesian_case, testee, "a")() - b = cases.allocate(cartesian_case, testee, "b")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() + a = cases.allocate(cartesian_case, fop_different_fields, "a")() + b = cases.allocate(cartesian_case, fop_different_fields, "b")() + out = 
cases.allocate(cartesian_case, fop_different_fields, cases.RETURN)() cases.verify( cartesian_case, - testee, + fop_different_fields, a, b, out=out, @@ -666,14 +666,14 @@ def test_direct_fo(cartesian_case): def test_direct_fo_nested_no_domain(cartesian_case): - a = cases.allocate(cartesian_case, testee_nested_tuples, "a")() - b = cases.allocate(cartesian_case, testee_nested_tuples, "b")() - c = cases.allocate(cartesian_case, testee_nested_tuples, "c")() - out = cases.allocate(cartesian_case, testee_nested_tuples, cases.RETURN)() + a = cases.allocate(cartesian_case, fop_nested_tuples, "a")() + b = cases.allocate(cartesian_case, fop_nested_tuples, "b")() + c = cases.allocate(cartesian_case, fop_nested_tuples, "c")() + out = cases.allocate(cartesian_case, fop_nested_tuples, cases.RETURN)() cases.verify( cartesian_case, - testee_nested_tuples, + fop_nested_tuples, a, b, c, From 0cc0bf415b12b13fb7a6a80da6111a6da8c9abf6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 17:06:05 +0200 Subject: [PATCH 33/44] improve past type deduction --- .../next/ffront/past_passes/type_deduction.py | 132 +++++++++--------- 1 file changed, 65 insertions(+), 67 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 0cba6a421d..d281687d2d 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -37,6 +37,54 @@ def _is_integral_scalar(expr: past.Expr) -> bool: return isinstance(expr.type, ts.ScalarType) and type_info.is_integral(expr.type) +def _validate_domain_out( + dom: past.Dict | past.TupleExpr, + out: ts.TypeSpec, + is_nested: bool = False, +) -> None: + if isinstance(dom, past.Dict): + # Only reject tuple outputs if nested + if is_nested and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): + raise ValueError("Domain dict cannot map to tuple outputs.") + + if len(dom.values_) == 0 and len(dom.keys_) == 
0: + raise ValueError("Empty domain not allowed.") + + for dim in dom.keys_: + if not isinstance(dim.type, ts.DimensionType): + raise ValueError( + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." + ) + + for domain_values in dom.values_: + if len(domain_values.elts) != 2: + raise ValueError( + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." + ) + if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( + domain_values.elts[1] + ): + raise ValueError( + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." + ) + + elif isinstance(dom, past.TupleExpr): + if isinstance(out, ts.TupleType): + out_elts = out.types + else: + raise ValueError(f"Tuple domain requires tuple output, got {type(out)}.") + + if len(dom.elts) != len(out_elts): + raise ValueError("Mismatched tuple lengths between domain and output.") + + for d, o in zip(dom.elts, out_elts): + assert isinstance(d, (past.Dict, past.TupleExpr)) + _validate_domain_out(d, o, is_nested=True) + + else: + raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.") + + def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: """ Perform checks for domain and output field types. 
@@ -53,58 +101,12 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: if "out" not in new_kwargs: raise ValueError("Missing required keyword argument 'out'.") - if "domain" in new_kwargs: + if (domain := new_kwargs.get("domain")) is not None: _ensure_no_sliced_field(new_kwargs["out"]) - - def validate_domain_out( - dom: past.Dict | past.TupleExpr, - out: past.TupleExpr | past.Name, - is_nested: bool = False, - ) -> None: - if isinstance(dom, past.Dict): - # Only reject tuple outputs if nested - if is_nested and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): - raise ValueError("Domain dict cannot map to tuple outputs.") - - if len(dom.values_) == 0 and len(dom.keys_) == 0: - raise ValueError("Empty domain not allowed.") - - for dim in dom.keys_: - if not isinstance(dim.type, ts.DimensionType): - raise ValueError( - f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." - ) - - for domain_values in dom.values_: - if len(domain_values.elts) != 2: - raise ValueError( - f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." - ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): - raise ValueError( - f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." 
- ) - - elif isinstance(dom, past.TupleExpr): - if isinstance(out, past.TupleExpr): - out_elts = out.elts - elif isinstance(out.type, ts.TupleType): - out_elts = out.type.types - else: - raise ValueError(f"Tuple domain requires tuple output, got {type(out)}.") - - if len(dom.elts) != len(out_elts): - raise ValueError("Mismatched tuple lengths between domain and output.") - - for d, o in zip(dom.elts, out_elts): - validate_domain_out(d, o, is_nested=True) - - else: - raise ValueError(f"'domain' must be Dict or TupleExpr, got {type(dom)}.") - - validate_domain_out(new_kwargs["domain"], new_kwargs["out"]) + assert isinstance(domain, (past.Dict, past.TupleExpr)) + out = new_kwargs["out"] + assert isinstance(out, past.Expr) and out.type is not None + _validate_domain_out(domain, out.type) class ProgramTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -157,24 +159,20 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute type=getattr(new_value.type, node.attr), ) + def visit_Dict(self, node: past.Dict, **kwargs: Any) -> past.Dict: + keys = self.visit(node.keys_, **kwargs) + assert all(isinstance(key, past.Name) for key in keys) + assert all(isinstance(key.type, ts.DimensionType) for key in keys) + return past.Dict( + keys_=keys, + values_=self.visit(node.values_, **kwargs), + location=node.location, + type=ts.DomainType(dims=[common.Dimension(key.id) for key in keys]), + ) + def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: elts = self.visit(node.elts, **kwargs) - if any(isinstance(elt, past.Dict) for elt in elts): - assert all(isinstance(elt, (past.Dict, past.TupleExpr)) for elt in elts) - - def infer_type(elt: past.Dict | past.TupleExpr) -> ts.DomainType | ts.TupleType: - if isinstance(elt, past.Dict): - assert all(isinstance(key, past.Name) for key in elt.keys_) - assert all(isinstance(key.type, ts.DimensionType) for key in elt.keys_) - return ts.DomainType(dims=[common.Dimension(key.id) 
for key in elt.keys_]) - elif isinstance(elt, past.TupleExpr): - return ts.TupleType(types=[infer_type(elt) for elt in elt.elts]) - else: - raise AssertionError(f"Unexpected element type {type(elt)} inside TupleExpr") - - ttype = ts.TupleType(types=[infer_type(elt) for elt in elts]) - else: - ttype = ts.TupleType(types=[elt.type for elt in elts]) + ttype = ts.TupleType(types=[elt.type for elt in elts]) return past.TupleExpr(elts=elts, type=ttype, location=node.location) From bfbfd36aa4b7fdf8a99fc2bff4ecb2a9cb46bf7d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 18:06:41 +0200 Subject: [PATCH 34/44] fix domain type deduction --- src/gt4py/next/ffront/past_passes/type_deduction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index d281687d2d..fdf4106a50 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -8,7 +8,7 @@ from typing import Any, Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next import common, errors +from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -160,14 +160,14 @@ def visit_Attribute(self, node: past.Attribute, **kwargs: Any) -> past.Attribute ) def visit_Dict(self, node: past.Dict, **kwargs: Any) -> past.Dict: + # the only supported dict for now is in domain specification keys = self.visit(node.keys_, **kwargs) - assert all(isinstance(key, past.Name) for key in keys) assert all(isinstance(key.type, ts.DimensionType) for key in keys) return past.Dict( keys_=keys, values_=self.visit(node.values_, **kwargs), location=node.location, - type=ts.DomainType(dims=[common.Dimension(key.id) for key in keys]), + type=ts.DomainType(dims=[key.type.dim for key in keys]), ) def visit_TupleExpr(self, node: past.TupleExpr, **kwargs: Any) -> past.TupleExpr: From 
93087d83dd941a54485b45d639e461fb0d2e3885 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 7 Oct 2025 21:50:47 +0200 Subject: [PATCH 35/44] cleanup tree_map like operations --- src/gt4py/next/embedded/operators.py | 16 +++++----------- src/gt4py/next/ffront/decorator.py | 19 +++++-------------- src/gt4py/next/iterator/ir_utils/misc.py | 15 +++++++++------ src/gt4py/next/utils.py | 1 + 4 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index a57df1f204..9398111e5c 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -131,17 +131,11 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> return op(*args, **kwargs) -def _get_vertical_range( - domain: xtyping.MaybeNestedInTuple[common.Domain], -) -> common.NamedRange | eve.NothingType | tuple[common.NamedRange | eve.NothingType, ...]: - if isinstance(domain, tuple): - return tuple(_get_vertical_range(dom) for dom in domain) - else: - vertical_dim_filtered = [ - nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL - ] - assert len(vertical_dim_filtered) <= 1 - return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING +@utils.tree_map +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr.dim.kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING def _tuple_assign_field( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2a000c035e..da9cb9fb4a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -33,6 +33,7 @@ embedded as next_embedded, errors, metrics, + utils, ) from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -629,18 
+630,6 @@ def program_inner(definition: types.FunctionType) -> Program: OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) -def _slice_outs( - outs: xtyping.MaybeNestedInTuple[common.Field], - domains: xtyping.MaybeNestedInTuple[common.Domain], -) -> common.Field | tuple[common.Field | tuple, ...]: - if isinstance(outs, tuple): - if not isinstance(domains, tuple): - domains = tuple([domains] * len(outs)) - return tuple(_slice_outs(out, domain) for out, domain in zip(outs, domains, strict=True)) - else: - return outs[domains] - - @dataclasses.dataclass(frozen=True) class FieldOperator(GTCallable, Generic[OperatorNodeT]): """ @@ -782,8 +771,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - dom = common.normalize_domains(kwargs.pop("domain")) - out = _slice_outs(out, dom) + domain = common.normalize_domains(kwargs.pop("domain")) + if not isinstance(domain, tuple): + domain = utils.tree_map(lambda _: domain)(out) + out = utils.tree_map(lambda f, dom: f[dom])(out, domain) args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 964e91609e..08f4746cfb 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -8,7 +8,7 @@ import dataclasses from collections import ChainMap -from typing import Callable, Iterable, TypeVar +from typing import Callable, Iterable, TypeVar, cast from gt4py import eve from gt4py._core import definitions as core_defs @@ -229,17 +229,20 @@ def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: return common.GridType.UNSTRUCTURED -def _flatten_tuple_expr(domain_expr: itir.Expr) -> tuple[itir.Expr]: - if cpm.is_call_to(domain_expr, "make_tuple"): - return sum((_flatten_tuple_expr(arg) for arg in domain_expr.args), start=()) 
+def _flatten_tuple_expr(expr: itir.Expr) -> tuple[itir.Expr]: + if cpm.is_call_to(expr, "make_tuple"): + return sum( + (_flatten_tuple_expr(arg) for arg in expr.args), start=cast(tuple[itir.Expr], ()) + ) else: - return (domain_expr,) + return (expr,) def grid_type_from_program(program: itir.Program) -> common.GridType: domain_exprs = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() domains = sum((_flatten_tuple_expr(domain_expr) for domain_expr in domain_exprs), start=()) - grid_types = {grid_type_from_domain(d) for d in domains} + assert all(isinstance(d, itir.FunCall) for d in domains) + grid_types = {grid_type_from_domain(d) for d in domains} # type: ignore[arg-type] # checked above if len(grid_types) != 1: raise ValueError( f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 75c0f68859..e03e5b8274 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -125,6 +125,7 @@ def tree_map( if result_collection_constructor is None: if isinstance(collection_type, tuple): + # Note: that doesn't mean `collection_type=tuple`, but e.g. `collection_type=(list, tuple)` raise TypeError( "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple of types." 
) From ddf95d0fccec024c6efa5890dfb61a4a45a07131 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 8 Oct 2025 13:04:48 +0200 Subject: [PATCH 36/44] refactor past_to_itir --- src/gt4py/next/ffront/past_to_itir.py | 189 +++++++++++++------------- 1 file changed, 91 insertions(+), 98 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2d4da265d1..f21e25ee4d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -164,10 +164,6 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] return iter(scanops_per_axis.keys()).__next__() -def _range_arg_from_field(field_name: str, dim: int) -> str: - return f"__{field_name}_{dim}_range" - - def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript | past.Dict]: if isinstance(node, (past.Name, past.Subscript, past.Dict)): return [node] @@ -181,6 +177,39 @@ def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript | pa ) +def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: + out_field_name: past.Name = node.value + out_field_slice_: list[past.Slice] + if isinstance(node.slice_, past.TupleExpr) and all( + isinstance(el, past.Slice) for el in node.slice_.elts + ): + out_field_slice_ = cast(list[past.Slice], node.slice_.elts) # type ensured by if + elif isinstance(node.slice_, past.Slice): + out_field_slice_ = [node.slice_] + else: + raise AssertionError( + "Unexpected 'out' argument, must be tuple of slices or slice expression." 
+ ) + node_dims = cast(ts.FieldType, node.type).dims + assert isinstance(node_dims, list) + if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims): + raise errors.DSLError( + node.location, + f"Too many indices for field '{out_field_name}': field is {len(node_dims)}" + f"-dimensional, but {len(out_field_slice_)} were indexed.", + ) + return out_field_slice_ + + +def _get_element_from_tuple_expr(node: past.Expr, path: tuple[int, ...]) -> past.Expr: + """Get element from a (nested) TupleExpr by following the given path. + + Pre-condition: `node` is a `past.TupleExpr` (if `path ! = ()`) + and `path` is a valid path through the nested tuple structure. + """ + return functools.reduce(lambda e, i: e.elts[i], path, node) # type: ignore[attr-defined] # see pre-condition + + @dataclasses.dataclass class ProgramLowering( traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator @@ -322,12 +351,12 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, - out_field: past.Name | past.Subscript, + out_expr: itir.Expr, + out_type: ts.TypeSpec, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: - assert isinstance(out_field.type, ts.TypeSpec) - out_field_types = type_info.primitive_constituents(out_field.type).to_list() + out_field_types = type_info.primitive_constituents(out_type).to_list() out_dims = cast(ts.FieldType, out_field_types[0]).dims if any( @@ -335,19 +364,18 @@ def _construct_itir_domain_arg( for out_field_type in out_field_types ): raise AssertionError( - f"Expected constituents of '{out_field.id}' argument to be" + f"Expected constituents of '{out_expr}' argument to be" " fields defined on the same dimensions. This error should be " " caught in type deduction already." 
) primitive_paths = [ - path for _, path in type_info.primitive_constituents(out_field.type, with_path_arg=True) + path for _, path in type_info.primitive_constituents(out_type, with_path_arg=True) ] - assert isinstance(out_field, (past.Name, past.Subscript)) or (hasattr(out_field, "id")) tuple_elements = [ functools.reduce( lambda expr, i: im.tuple_get(i, expr), path, - out_field.value.id if isinstance(out_field, past.Subscript) else out_field.id, + out_expr, ) for path in primitive_paths ] @@ -402,7 +430,7 @@ def _construct_itir_domain_arg( return itir.FunCall( fun=itir.SymRef(id=domain_builtin), args=domain_args, - location=(node_domain or out_field).location, + location=(node_domain or out_expr).location, ) def _construct_itir_initialized_domain_arg( @@ -418,100 +446,65 @@ def _construct_itir_initialized_domain_arg( return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] - @staticmethod - def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: - out_field_name: past.Name = node.value - out_field_slice_: list[past.Slice] - if isinstance(node.slice_, past.TupleExpr) and all( - isinstance(el, past.Slice) for el in node.slice_.elts - ): - out_field_slice_ = cast(list[past.Slice], node.slice_.elts) # type ensured by if - elif isinstance(node.slice_, past.Slice): - out_field_slice_ = [node.slice_] + def _split_field_and_slice( + self, field: past.Name | past.Subscript + ) -> tuple[itir.SymRef, list[past.Slice] | None]: + if isinstance(field, past.Subscript): + return self.visit(field.value), _compute_field_slice(field) else: - raise AssertionError( - "Unexpected 'out' argument, must be tuple of slices or slice expression." 
- ) - node_dims = cast(ts.FieldType, node.type).dims - assert isinstance(node_dims, list) - if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims): - raise errors.DSLError( - node.location, - f"Too many indices for field '{out_field_name}': field is {len(node_dims)}" - f"-dimensional, but {len(out_field_slice_)} were indexed.", - ) - return out_field_slice_ + assert isinstance(field, past.Name) + return self.visit(field), None + + def _get_field_and_slice( + self, out_arg: past.TupleExpr, path: tuple[int, ...] + ) -> tuple[itir.SymRef, list[past.Slice] | None]: + """Extract field and its slice for a given path through the tuple structure.""" + current_field = _get_element_from_tuple_expr(out_arg, path) + assert isinstance(current_field, (past.Name, past.Subscript)) + return self._split_field_and_slice(current_field) def _visit_stencil_call_out_arg( self, out_arg: past.Expr, domain_arg: Optional[past.Expr], **kwargs: Any ) -> tuple[itir.Expr, itir.FunCall]: - if isinstance(out_arg, past.Subscript): - # as the ITIR does not support slicing a field we have to do a deeper - # inspection of the PAST to emulate the behaviour - out_field_name: past.Name = out_arg.value - return ( - self._construct_itir_out_arg(out_field_name), - self._construct_itir_domain_arg( - out_field_name, domain_arg, self._compute_field_slice(out_arg) - ), - ) - elif isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.FieldType): - return ( - self._construct_itir_out_arg(out_arg), - self._construct_itir_domain_arg(out_arg, domain_arg), - ) - elif isinstance(out_arg, past.TupleExpr) or ( - isinstance(out_arg, past.Name) and isinstance(out_arg.type, ts.TupleType) - ): - - def get_field_and_slice(path: tuple[int, ...]): - """Extract field and its slice for a given path through the tuple structure.""" - current_field = functools.reduce(lambda e, i: e.elts[i], path, out_arg) + assert isinstance(out_arg, (past.Subscript, past.Name, past.TupleExpr)), ( + "Unexpected 
'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." + ) - if isinstance(current_field, past.Subscript): - return current_field.value, self._compute_field_slice(current_field) - else: - return current_field, None - - domain_expr = type_info.apply_to_primitive_constituents( - lambda field_type, path: ( - lambda field, slice_info: self._construct_itir_domain_arg( - field, - functools.reduce(lambda e, i: e.elts[i], path, domain_arg) - if isinstance(domain_arg, past.TupleExpr) - else domain_arg, - slice_info, - ) - )(*get_field_and_slice(path)) - if isinstance(out_arg, past.TupleExpr) - else self._construct_itir_domain_arg( - # Create a temporary past.Name-like object that carries the indexed information - type( - "NameLikeObject", - (), - { - "id": functools.reduce( - lambda expr, i: im.tuple_get(i, expr), path, out_arg.id - ), - "type": field_type, - "location": out_arg.location, - }, - )(), - functools.reduce(lambda e, i: e.elts[i], path, domain_arg) - if isinstance(domain_arg, past.TupleExpr) - else domain_arg, - None, # Name with TupleType doesn't support per-field slicing - ), - out_arg.type, - with_path_arg=True, - tuple_constructor=im.make_tuple, + def generate_nested_domain_expr( + field_type: ts.TypeSpec, path: tuple[int, ...] + ) -> itir.FunCall: + if isinstance(out_arg, past.TupleExpr): + # If the out_arg is a TupleExpr, we directly extract the node... + expr, slice_info = self._get_field_and_slice(out_arg, path) + else: + # ... otherwise we construct an expression to extract the field from a symbol. 
+ # Note this code path works for + # - `out_arg` being a single field with and without slicing + # - `out_arg` being a (nested) tuple of fields (always without slicing) + name, slice_info = self._split_field_and_slice(out_arg) + expr = functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, name) + + domain_element = ( + _get_element_from_tuple_expr(domain_arg, path) + if isinstance(domain_arg, past.TupleExpr) + else domain_arg ) - return self._construct_itir_out_arg(out_arg), domain_expr - else: - raise AssertionError( - "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." + return self._construct_itir_domain_arg( + expr, + field_type, + domain_element, + slice_info, ) + assert out_arg.type is not None + domain_expr = type_info.apply_to_primitive_constituents( + generate_nested_domain_expr, + out_arg.type, + with_path_arg=True, + tuple_constructor=im.make_tuple, + ) + return self._construct_itir_out_arg(out_arg), domain_expr + def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: if isinstance(node.type, ts.ScalarType) and node.type.shape is None: match node.type.kind: @@ -525,7 +518,7 @@ def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: raise NotImplementedError("Only scalar literals supported currently.") def visit_Name(self, node: past.Name, **kwargs: Any) -> itir.SymRef: - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) def visit_Symbol(self, node: past.Symbol, **kwargs: Any) -> itir.Sym: return itir.Sym(id=node.id, type=node.type) From 2b33bc559db8343961fec398ffb1e14df52142a9 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 19 Oct 2025 18:47:00 +0200 Subject: [PATCH 37/44] Simplify and cleanup past_to_itir --- .../next/ffront/past_passes/type_deduction.py | 12 +- src/gt4py/next/ffront/past_to_itir.py | 178 +++++++----------- src/gt4py/next/ffront/program_ast.py | 8 +- src/gt4py/next/utils.py | 70 ++++++- 4 
files changed, 146 insertions(+), 122 deletions(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index fdf4106a50..42330cbaaf 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -44,10 +44,11 @@ def _validate_domain_out( ) -> None: if isinstance(dom, past.Dict): # Only reject tuple outputs if nested - if is_nested and (isinstance(out, past.TupleExpr) or isinstance(out, ts.TupleType)): + if is_nested and isinstance(out, ts.TupleType): raise ValueError("Domain dict cannot map to tuple outputs.") + assert not (is_nested and isinstance(out, past.TupleExpr)) - if len(dom.values_) == 0 and len(dom.keys_) == 0: + if len(dom.keys_) == 0: raise ValueError("Empty domain not allowed.") for dim in dom.keys_: @@ -61,9 +62,7 @@ def _validate_domain_out( raise ValueError( f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) - if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( - domain_values.elts[1] - ): + if any(not _is_integral_scalar(el) for el in domain_values.elts): raise ValueError( f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." 
) @@ -77,7 +76,7 @@ def _validate_domain_out( if len(dom.elts) != len(out_elts): raise ValueError("Mismatched tuple lengths between domain and output.") - for d, o in zip(dom.elts, out_elts): + for d, o in zip(dom.elts, out_elts, strict=True): assert isinstance(d, (past.Dict, past.TupleExpr)) _validate_domain_out(d, o, is_nested=True) @@ -103,7 +102,6 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict) -> None: raise ValueError("Missing required keyword argument 'out'.") if (domain := new_kwargs.get("domain")) is not None: _ensure_no_sliced_field(new_kwargs["out"]) - assert isinstance(domain, (past.Dict, past.TupleExpr)) out = new_kwargs["out"] assert isinstance(out, past.Expr) and out.type is not None _validate_domain_out(domain, out.type) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f21e25ee4d..ba5194a6b0 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Optional, cast +from typing import Any, Optional, Sequence, cast import devtools @@ -164,19 +164,6 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] return iter(scanops_per_axis.keys()).__next__() -def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript | past.Dict]: - if isinstance(node, (past.Name, past.Subscript, past.Dict)): - return [node] - elif isinstance(node, past.TupleExpr): - result = [] - for e in node.elts: - result.extend(_flatten_tuple_expr(e)) - return result - raise ValueError( - f"Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed, got '{type(node)}'." 
- ) - - def _compute_field_slice(node: past.Subscript) -> list[past.Slice]: out_field_name: past.Name = node.value out_field_slice_: list[past.Slice] @@ -210,6 +197,20 @@ def _get_element_from_tuple_expr(node: past.Expr, path: tuple[int, ...]) -> past return functools.reduce(lambda e, i: e.elts[i], path, node) # type: ignore[attr-defined] # see pre-condition +def _unwrap_tuple_expr(expr: past.Expr, path: tuple[int, ...]) -> tuple[past.Expr, Sequence[int]]: + """Unwrap (nested) TupleExpr by following the given path as long as possible. + + If a non-tuple expression is encountered, the current expression and the remaining path are + returned. + """ + path_remainder: Sequence[int] = path + while isinstance(expr, past.TupleExpr): + idx, *path_remainder = path_remainder + expr = expr.elts[idx] + + return expr, path_remainder + + @dataclasses.dataclass class ProgramLowering( traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator @@ -352,73 +353,46 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: def _construct_itir_domain_arg( self, out_expr: itir.Expr, - out_type: ts.TypeSpec, + out_type: ts.FieldType, node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: - out_field_types = type_info.primitive_constituents(out_type).to_list() - - out_dims = cast(ts.FieldType, out_field_types[0]).dims - if any( - not isinstance(out_field_type, ts.FieldType) or out_field_type.dims != out_dims - for out_field_type in out_field_types - ): - raise AssertionError( - f"Expected constituents of '{out_expr}' argument to be" - " fields defined on the same dimensions. This error should be " - " caught in type deduction already." 
- ) - primitive_paths = [ - path for _, path in type_info.primitive_constituents(out_type, with_path_arg=True) - ] - tuple_elements = [ - functools.reduce( - lambda expr, i: im.tuple_get(i, expr), - path, - out_expr, + domain_args = [] + for dim_i, dim in enumerate(out_type.dims): + # an expression for the range of a dimension + dim_range = im.call("get_domain_range")( + out_expr, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) - for path in primitive_paths - ] - domain_args = [] - for el in tuple_elements: - for dim_i, dim in enumerate(out_dims): - # an expression for the range of a dimension - dim_range = im.call("get_domain_range")( - el, itir.AxisLiteral(value=dim.value, kind=dim.kind) + dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) + # bounds + lower: itir.Expr + upper: itir.Expr + if node_domain is not None: + assert isinstance(node_domain, past.Dict) + lower, upper = self._construct_itir_initialized_domain_arg(dim_i, dim, node_domain) + else: + lower = self._visit_slice_bound( + slices[dim_i].lower if slices else None, + dim_start, + dim_start, + dim_stop, + ) + upper = self._visit_slice_bound( + slices[dim_i].upper if slices else None, + dim_stop, + dim_start, + dim_stop, ) - dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) - # bounds - lower: itir.Expr - upper: itir.Expr - if node_domain is not None: - assert isinstance(node_domain, past.Dict) - lower, upper = self._construct_itir_initialized_domain_arg( - dim_i, dim, node_domain - ) - else: - lower = self._visit_slice_bound( - slices[dim_i].lower if slices else None, - dim_start, - dim_start, - dim_stop, - ) - upper = self._visit_slice_bound( - slices[dim_i].upper if slices else None, - dim_stop, - dim_start, - dim_stop, - ) - - if dim.kind == common.DimensionKind.LOCAL: - raise ValueError(f"common.Dimension '{dim.value}' must not be local.") - domain_args.append( - itir.FunCall( - fun=itir.SymRef(id="named_range"), - 
args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], - ) + if dim.kind == common.DimensionKind.LOCAL: + raise ValueError(f"common.Dimension '{dim.value}' must not be local.") + domain_args.append( + itir.FunCall( + fun=itir.SymRef(id="named_range"), + args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], ) + ) if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" @@ -448,20 +422,12 @@ def _construct_itir_initialized_domain_arg( def _split_field_and_slice( self, field: past.Name | past.Subscript - ) -> tuple[itir.SymRef, list[past.Slice] | None]: + ) -> tuple[past.Name, list[past.Slice] | None]: if isinstance(field, past.Subscript): - return self.visit(field.value), _compute_field_slice(field) + return field.value, _compute_field_slice(field) else: assert isinstance(field, past.Name) - return self.visit(field), None - - def _get_field_and_slice( - self, out_arg: past.TupleExpr, path: tuple[int, ...] - ) -> tuple[itir.SymRef, list[past.Slice] | None]: - """Extract field and its slice for a given path through the tuple structure.""" - current_field = _get_element_from_tuple_expr(out_arg, path) - assert isinstance(current_field, (past.Name, past.Subscript)) - return self._split_field_and_slice(current_field) + return field, None def _visit_stencil_call_out_arg( self, out_arg: past.Expr, domain_arg: Optional[past.Expr], **kwargs: Any @@ -470,40 +436,36 @@ def _visit_stencil_call_out_arg( "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." ) - def generate_nested_domain_expr( - field_type: ts.TypeSpec, path: tuple[int, ...] - ) -> itir.FunCall: - if isinstance(out_arg, past.TupleExpr): - # If the out_arg is a TupleExpr, we directly extract the node... - expr, slice_info = self._get_field_and_slice(out_arg, path) - else: - # ... otherwise we construct an expression to extract the field from a symbol. 
- # Note this code path works for - # - `out_arg` being a single field with and without slicing - # - `out_arg` being a (nested) tuple of fields (always without slicing) - name, slice_info = self._split_field_and_slice(out_arg) - expr = functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, name) + @gtx_utils.tree_map( + collection_type=ts.TupleType, + with_path_arg=True, + unpack=True, + result_collection_constructor=lambda elts: im.make_tuple(*elts), + ) + def impl(out_type: ts.FieldType, path: tuple[int, ...]) -> tuple[itir.Expr, itir.Expr]: + out_field, path_remainder = _unwrap_tuple_expr(out_arg, path) + + assert isinstance(out_field, (past.Name, past.Subscript)) + out_field, slice_info = self._split_field_and_slice(out_field) domain_element = ( _get_element_from_tuple_expr(domain_arg, path) if isinstance(domain_arg, past.TupleExpr) else domain_arg ) - return self._construct_itir_domain_arg( - expr, - field_type, + + lowered_out_field = functools.reduce( + lambda expr, i: im.tuple_get(i, expr), path_remainder, self.visit(out_field) + ) + lowered_domain = self._construct_itir_domain_arg( + lowered_out_field, + out_type, domain_element, slice_info, ) + return lowered_out_field, lowered_domain - assert out_arg.type is not None - domain_expr = type_info.apply_to_primitive_constituents( - generate_nested_domain_expr, - out_arg.type, - with_path_arg=True, - tuple_constructor=im.make_tuple, - ) - return self._construct_itir_out_arg(out_arg), domain_expr + return impl(out_arg.type) def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal: if isinstance(node.type, ts.ScalarType) and node.type.shape is None: diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index ea579aa211..9e0eb30939 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -9,7 +9,7 @@ from typing import Any, Generic, Literal, Optional, TypeVar, Union import gt4py.eve as eve -from gt4py.eve import 
Coerced, Node, SourceLocation, SymbolName, SymbolRef +from gt4py.eve import Coerced, Node, SourceLocation, SymbolName, SymbolRef, datamodels from gt4py.eve.traits import SymbolTableTrait from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront from gt4py.next.type_system import type_specifications as ts @@ -85,6 +85,12 @@ class Dict(Expr): keys_: list[Union[Name | Attribute]] values_: list[TupleExpr] + @datamodels.root_validator + @classmethod + def keys_values_length_validation(cls: type["Dict"], instance: "Dict") -> None: + if len(instance.keys_) != len(instance.values_): + raise ValueError("`Dict` must have same number of keys as values.") + class Slice(Expr): lower: Optional[Constant] diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index e03e5b8274..e738c7eb26 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -8,7 +8,18 @@ import functools import itertools -from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload +from typing import ( + Any, + Callable, + ClassVar, + Optional, + ParamSpec, + Sequence, + TypeGuard, + TypeVar, + cast, + overload, +) class RecursionGuard: @@ -74,6 +85,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[type | Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... @@ -82,6 +95,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[type | Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[ [Callable[_P, _R]], Callable[..., Any] ]: ... # TODO(havogt): if result_collection_constructor is Callable, improve typing @@ -92,6 +107,8 @@ def tree_map( *, collection_type: type | tuple[type, ...] 
= tuple, result_collection_constructor: Optional[type | Callable] = None, + unpack: bool = False, + with_path_arg: bool = False, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: """ Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). @@ -100,7 +117,9 @@ def tree_map( fun: Function to apply to each entry of the collection. collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types. result_collection_constructor: Type of the collection to be returned. If `None` the same type as `collection_type` is used. - + unpack: Replicate tuple structure returned from `fun` to the mapped result, i.e. return + tuple of result collections instead of result collections of tuples. + with_path_arg: Pass the path to access the current element to `fun`. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) ((2, 3), 4) @@ -121,6 +140,27 @@ def tree_map( ... return x + 1 >>> impl(((1, 2), 3)) ((2, 3), 4) + + >>> @tree_map(with_path_arg=True) + ... def impl(x, path: tuple[int, ...]): + ... path_str = "".join(f"[{i}]" for i in path) + ... return f"t{path_str} = {x}" + >>> t = impl(((1, 2), 3)) + >>> t[0][0] + 't[0][0] = 1' + >>> t[0][1] + 't[0][1] = 2' + >>> t[0][0] + 't[1] = 3' + + >>> @tree_map(unpack=True) + ... def impl(x): + ... 
return (x, x**2) + >>> identity, squared = impl(((2, 3), 4)) + >>> identity + ((2, 3), 4) + >>> squared + ((4, 9), 16) """ if result_collection_constructor is None: @@ -136,22 +176,40 @@ def tree_map( @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: if isinstance(args[0], collection_type): + non_path_args: Sequence[Any] + if with_path_arg: + *non_path_args, path = args + args = (*non_path_args, tuple((*path, i) for i in range(len(args[0])))) + else: + non_path_args = args + assert all( - isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args + isinstance(arg, collection_type) and len(args[0]) == len(arg) + for arg in non_path_args ) assert result_collection_constructor is not None - return result_collection_constructor(impl(*arg) for arg in zip(*args)) + + mapped = [impl(*arg) for arg in zip(*args)] + if unpack: + return tuple(map(result_collection_constructor, zip(*mapped))) + else: + return result_collection_constructor(mapped) return fun( # type: ignore[call-arg] - *cast(_P.args, args) # type: ignore[valid-type] + *cast(_P.args, args), # type: ignore[valid-type] ) # mypy doesn't understand that `args` at this point is of type `_P.args` - return impl + if with_path_arg: + return lambda *args: impl(*args, ()) + else: + return impl else: return functools.partial( tree_map, collection_type=collection_type, result_collection_constructor=result_collection_constructor, + unpack=unpack, + with_path_arg=with_path_arg, ) From 35a8966071bc1b9645bb9ba994f434758ab29728 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 19 Oct 2025 19:41:34 +0200 Subject: [PATCH 38/44] Fix doctest --- src/gt4py/next/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index e738c7eb26..12e4118e8d 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -150,7 +150,7 @@ def tree_map( 't[0][0] = 1' >>> t[0][1] 't[0][1] = 2' - >>> 
t[0][0] + >>> t[1] 't[1] = 3' >>> @tree_map(unpack=True) From 1888c2d0d46d44f5d79db7c31ba9aaf029a68f4d Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 29 Oct 2025 16:27:22 +0100 Subject: [PATCH 39/44] SDFG lowering of multiple output domains (#16) --- .../runners/dace/gtir_domain.py | 26 +++- .../runners/dace/gtir_to_sdfg.py | 115 ++++++++++++------ .../runners/dace/gtir_to_sdfg_primitives.py | 27 ++-- .../runners/dace/gtir_to_sdfg_scan.py | 110 ++++++++++------- .../runners/dace/gtir_to_sdfg_types.py | 21 ++-- .../runners/dace/gtir_to_sdfg_utils.py | 2 + .../ffront_tests/test_compiled_program.py | 1 + .../dace_tests/test_gtir_to_sdfg.py | 53 ++++++-- 8 files changed, 235 insertions(+), 120 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py index fb47f9e1d6..68bd1976b2 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py @@ -14,8 +14,11 @@ import dace from dace import subsets as dace_subsets +from gt4py import eve +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common -from gt4py.next.iterator.ir_utils import domain_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.program_processors.runners.dace import gtir_to_sdfg_utils @@ -57,6 +60,26 @@ def get_field_domain(domain: domain_utils.SymbolicDomain) -> FieldopDomain: ] +TargetDomain: TypeAlias = MaybeNestedInTuple[domain_utils.SymbolicDomain] +"""Symbolic domain which defines the range to write in the target field. + +For tuple output, the corresponding domain in fieldview is a tuple of domains. 
+""" + + +class TargetDomainParser(eve.visitors.NodeTranslator): + """Visitor class to build a `TargetDomain` symbolic domain.""" + + def visit_FunCall(self, node: gtir.FunCall) -> TargetDomain: + if cpm.is_call_to(node, "make_tuple"): + return tuple(self.visit(arg) for arg in node.args) + else: + return domain_utils.SymbolicDomain.from_expr(node) + + def apply(cls, node: gtir.Expr) -> TargetDomain: + return cls.visit(node) + + def get_domain_indices( dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]] ) -> dace_subsets.Indices: @@ -99,7 +122,6 @@ def get_field_layout( Args: field_domain: The field operator domain. - target_domain: Domain of the target field in the root `SetAt` expression. Returns: A tuple of three lists containing: diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index 074ccb3626..5ebe84e337 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -683,7 +683,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: def visit_SetAt( self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dace.SDFGState: + ) -> dace.SDFGState | None: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. @@ -694,8 +694,7 @@ def visit_SetAt( """ # Visit the domain expression. - assert isinstance(stmt.domain.type, ts.DomainType) - domain = domain_utils.SymbolicDomain.from_expr(stmt.domain) + domain = gtir_domain.TargetDomainParser().apply(stmt.domain) # Visit the field operator expression. 
source_tree = self._visit_expression(stmt.expr, sdfg, state) @@ -754,11 +753,24 @@ def _visit_target( ), ) - gtx_utils.tree_map( - lambda source, target, _domain=domain, _target_state=target_state: _visit_target( - source, target, _domain, _target_state - ) - )(source_tree, target_tree) + if isinstance(target_tree, tuple) and not isinstance(domain, tuple): + # This branch handles a specific case that indeed never happens in + # fieldview GTIR, only in iterator GTIR tests. The case corresponds + # to 'as_fieldop' with tuple output and single domain, which is a format + # used when multiple 'as_fieldop' are fused into one. The input to SDFG + # lowering is fieldview IR, where 'as_fieldop' will always have a single + # domain and the frontend will never emit 'as_fieldop' with tuple output. + gtx_utils.tree_map( + lambda source, target, domain_=domain, target_state_=target_state: _visit_target( + source, target, domain_, target_state_ + ) + )(source_tree, target_tree) + else: + gtx_utils.tree_map( + lambda source, target, domain_, target_state_=target_state: _visit_target( + source, target, domain_, target_state_ + ) + )(source_tree, target_tree, domain) if target_state.is_empty(): sdfg.remove_node(target_state) @@ -808,12 +820,12 @@ def visit_FunCall( symbolic_args[str(p.id)] = symbolic_expr # All other lambda arguments are lowered to some dataflow that produces a data node. 
args = { - str(p.id): ( - gtir_to_sdfg_types.SymbolicData(p.type, symbolic_args[param]) # type: ignore[arg-type] - if (param := str(p.id)) in symbolic_args + param: ( + gtir_to_sdfg_types.SymbolicData(param.type, symbolic_args[pname]) # type: ignore[arg-type] + if (pname := str(param.id)) in symbolic_args else self.visit(arg, ctx=ctx) ) - for p, arg in zip(node.fun.params, node.args, strict=True) + for param, arg in zip(node.fun.params, node.args, strict=True) } return self.visit(node.fun, ctx=ctx, args=args) elif isinstance(node.type, ts.ScalarType): @@ -825,7 +837,7 @@ def visit_Lambda( self, node: gtir.Lambda, ctx: SubgraphContext, - args: Mapping[str, gtir_to_sdfg_types.FieldopResult | gtir_to_sdfg_types.SymbolicData], + args: Mapping[gtir.Sym, gtir_to_sdfg_types.FieldopResult | gtir_to_sdfg_types.SymbolicData], ) -> gtir_to_sdfg_types.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -846,20 +858,24 @@ def visit_Lambda( the previous symbol during traversal of the lambda expression. 
""" - data_args = { - param: arg - for param, arg in args.items() - if not isinstance(arg, gtir_to_sdfg_types.SymbolicData) - } symbolic_args = { - param: arg + str(param.id): arg for param, arg in args.items() if isinstance(arg, gtir_to_sdfg_types.SymbolicData) } + data_args: dict[str, gtir_to_sdfg_types.FieldopResult] = { + str(param.id): arg # type: ignore[misc] # symbolic args are filtered out + for param, arg in args.items() + if arg is not None and param.id not in symbolic_args + } lambda_arg_nodes = dict( itertools.chain( - *[gtir_to_sdfg_types.flatten_tuples(param, arg) for param, arg in data_args.items()] + *[ + gtir_to_sdfg_types.flatten_tuple(param, arg) # type: ignore[arg-type] # symbolic args are filtered out + for param, arg in args.items() + if param.id in data_args + ] ) ) @@ -867,14 +883,16 @@ def visit_Lambda( lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) - } | { - param: gtir_to_sdfg_types.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type - for param, arg in args.items() - } + } | {str(param.id): param.type for param, arg in args.items() if arg is not None} + assert all(isinstance(_type, ts.DataType) for _type in lambda_symbols.values()) # lower let-statement lambda node as a nested SDFG lamnda_translator, lambda_ctx = self.setup_nested_context( - node, "lambda", ctx, lambda_symbols, symbolic_inputs=set(symbolic_args.keys()) + expr=node, + sdfg_name="lambda", + parent_ctx=ctx, + scope_symbols=lambda_symbols, # type: ignore[arg-type] # lambda_symbols checked by assert above + symbolic_inputs=set(symbolic_args.keys()), ) lambda_result = lamnda_translator.visit(node.expr, ctx=lambda_ctx) @@ -893,11 +911,21 @@ def visit_Lambda( for nsdfg_dataname, nsdfg_datadesc in lambda_ctx.sdfg.arrays.items(): if nsdfg_datadesc.transient: continue - - if nsdfg_dataname in lambda_arg_nodes: - src_node = lambda_arg_nodes[nsdfg_dataname].dc_node - dataname = 
src_node.data - datadesc = src_node.desc(ctx.sdfg) + elif nsdfg_dataname in lambda_arg_nodes: + arg_node = lambda_arg_nodes[nsdfg_dataname] + if arg_node is None: + # This argument has empty domain, which means that it should not be + # used inside the nested SDFG, and does not need to be connected outside. + assert all( + node.data != nsdfg_dataname + for node in lambda_ctx.sdfg.all_nodes_recursive() + if isinstance(node, dace.nodes.AccessNode) + ) + lambda_ctx.sdfg.arrays[nsdfg_dataname].transient = True + continue + else: + dataname = arg_node.dc_node.data + datadesc = arg_node.dc_node.desc(ctx.sdfg) else: dataname = nsdfg_dataname datadesc = ctx.sdfg.arrays[nsdfg_dataname] @@ -939,7 +967,8 @@ def visit_Lambda( nsdfg_symbols_mapping = {} for sym in lambda_ctx.sdfg.free_symbols: if (sym_id := str(sym)) in lambda_arg_nodes: - assert isinstance(lambda_arg_nodes[sym_id].gt_type, ts.ScalarType) + arg_node = lambda_arg_nodes[sym_id] + assert arg_node and isinstance(arg_node.gt_type, ts.ScalarType) raise NotImplementedError( "Unexpected mapping of scalar node to symbol on nested SDFG." 
) @@ -958,13 +987,24 @@ def visit_Lambda( debuginfo=gtir_to_sdfg_utils.debug_info(node, default=ctx.sdfg.debuginfo), ) - for connector, memlet in input_memlets.items(): - if connector in lambda_arg_nodes: - src_node = lambda_arg_nodes[connector].dc_node + for input_connector, memlet in input_memlets.items(): + if input_connector in lambda_arg_nodes: + arg_node = lambda_arg_nodes[input_connector] + if arg_node is None: + # this argument has empty domain, therefore it should not be used inside the nested SDFG + assert all( + node.data != input_connector + for node in lambda_ctx.sdfg.all_nodes_recursive() + if isinstance(node, dace.nodes.AccessNode) + ) + lambda_ctx.sdfg.arrays[input_connector].transient = True + continue + else: + src_node = arg_node.dc_node else: src_node = ctx.state.add_access(memlet.data) - ctx.state.add_edge(src_node, None, nsdfg_node, connector, memlet) + ctx.state.add_edge(src_node, None, nsdfg_node, input_connector, memlet) def construct_output_for_nested_sdfg( inner_data: gtir_to_sdfg_types.FieldopData, @@ -1003,7 +1043,10 @@ def construct_output_for_nested_sdfg( # Non-transient nodes are just input nodes that are immediately returned # by the lambda expression. Therefore, these nodes are already available # in the parent context and can be directly accessed there. - outer_data = lambda_arg_nodes[inner_dataname] + outer_arg = lambda_arg_nodes[inner_dataname] + if outer_arg is None: + raise ValueError(f"Unexpected argument with empty domain {inner_data}.") + outer_data = outer_arg else: # This must be a symbol captured from the lambda parent scope. 
outer_node = ctx.state.add_access(inner_dataname) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index 409b32b5c3..78fab48a16 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -14,6 +14,7 @@ import dace from dace import subsets as dace_subsets +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ( @@ -72,20 +73,16 @@ def _parse_fieldop_arg( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, domain: gtir_domain.FieldopDomain, -) -> ( - gtir_dataflow.IteratorExpr - | gtir_dataflow.MemletExpr - | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr | tuple[Any, ...], ...] -): - """Helper method to visit an expression passed as argument to a field operator.""" - +) -> MaybeNestedInTuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr]: + """ + Helper method to visit an expression passed as argument to a field operator + and create the local view for the field argument. 
+ """ arg = sdfg_builder.visit(node, ctx=ctx) - if isinstance(arg, gtir_to_sdfg_types.FieldopData): - return arg.get_local_view(domain, ctx.sdfg) - else: - # handle tuples of fields - return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain))(arg) + if not isinstance(arg, gtir_to_sdfg_types.FieldopData): + raise ValueError("Expected a field, found a tuple of fields.") + return arg.get_local_view(domain, ctx.sdfg) def _create_field_operator_impl( @@ -234,8 +231,8 @@ def _create_field_operator( # handle tuples of fields output_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("x", node_type) return gtx_utils.tree_map( - lambda output_edge, output_sym: _create_field_operator_impl( - ctx, sdfg_builder, domain, output_edge, output_sym.type, map_exit + lambda edge, sym, ctx_=ctx: _create_field_operator_impl( + ctx_, sdfg_builder, domain, edge, sym.type, map_exit ) )(output_tree, output_symbol_tree) @@ -620,7 +617,7 @@ def translate_tuple_get( if isinstance(data_nodes, gtir_to_sdfg_types.FieldopData): raise ValueError(f"Invalid tuple expression {node}") unused_arg_nodes: Iterable[gtir_to_sdfg_types.FieldopData] = gtx_utils.flatten_nested_tuple( - tuple(arg for i, arg in enumerate(data_nodes) if i != index) + tuple(arg for i, arg in enumerate(data_nodes) if arg is not None and i != index) ) ctx.state.remove_nodes_from( [arg.dc_node for arg in unused_arg_nodes if ctx.state.degree(arg.dc_node) == 0] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index c8d2b16995..3550bd140a 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -29,6 +29,7 @@ from dace import subsets as dace_subsets from gt4py import eve +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as 
gtir from gt4py.next.iterator.ir_utils import ( @@ -36,6 +37,7 @@ domain_utils, ir_makers as im, ) +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.program_processors.runners.dace import ( gtir_dataflow, gtir_domain, @@ -88,11 +90,11 @@ def _parse_fieldop_arg_impl( def _create_scan_field_operator_impl( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - field_domain: gtir_domain.FieldopDomain, - output_edge: gtir_dataflow.DataflowOutputEdge, + output_edge: gtir_dataflow.DataflowOutputEdge | None, + output_domain: infer_domain.NonTupleDomainAccess, output_type: ts.FieldType, map_exit: dace.nodes.MapExit, -) -> gtir_to_sdfg_types.FieldopData: +) -> gtir_to_sdfg_types.FieldopData | None: """ Helper method to allocate a temporary array that stores one field computed by the scan field operator. @@ -108,6 +110,13 @@ def _create_scan_field_operator_impl( Refer to `gtir_to_sdfg_primitives._create_field_operator_impl()` for the description of function arguments and return values. """ + if output_edge is None: + # According to domain inference, this tuple field does not need to be computed. 
+ assert output_domain == infer_domain.DomainAccessDescriptor.NEVER + return None + assert isinstance(output_domain, domain_utils.SymbolicDomain) + field_domain = gtir_domain.get_field_domain(output_domain) + dataflow_output_desc = output_edge.result.dc_node.desc(ctx.sdfg) assert isinstance(dataflow_output_desc, dace.data.Array) @@ -201,8 +210,8 @@ def _create_scan_field_operator( node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_to_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_tree: gtir_dataflow.DataflowOutputEdge - | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], + output: MaybeNestedInTuple[gtir_dataflow.DataflowOutputEdge | None], + output_domain: infer_domain.DomainAccess, ) -> gtir_to_sdfg_types.FieldopResult: """ Helper method to build the output of a field operator, which can consist of @@ -246,27 +255,28 @@ def _create_scan_field_operator( edge.connect(map_entry) if isinstance(node_type, ts.FieldType): - assert isinstance(output_tree, gtir_dataflow.DataflowOutputEdge) + assert isinstance(output, gtir_dataflow.DataflowOutputEdge) + assert isinstance(output_domain, domain_utils.SymbolicDomain) return _create_scan_field_operator_impl( - ctx, sdfg_builder, field_domain, output_tree, node_type, map_exit + ctx, sdfg_builder, output, output_domain, node_type, map_exit ) else: - # handle tuples of fields - # the symbol name 'x' in the call below is not used, we only need - # the tree structure of the `TupleType` definition to pass to `tree_map()` + # Handle tuples of fields. note that the symbol name 'x' in the call below + # is not used, we only need the tree structure of the `TupleType` definition + # to pass to `tree_map()` in order to retrieve the type of each field. 
output_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("x", node_type) return gtx_utils.tree_map( - lambda output_edge, output_sym: ( + lambda edge_, domain_, sym_, ctx_=ctx: ( _create_scan_field_operator_impl( - ctx, + ctx_, sdfg_builder, - field_domain, - output_edge, - output_sym.type, + edge_, + domain_, + sym_.type, map_exit, ) ) - )(output_tree, output_symbol_tree) + )(output, output_domain, output_symbol_tree) def _scan_input_name(input_name: str) -> str: @@ -375,6 +385,7 @@ def get_scan_output_shape( if isinstance(init_data, tuple): lambda_result_shape = gtx_utils.tree_map(get_scan_output_shape)(init_data) else: + assert init_data is not None lambda_result_shape = get_scan_output_shape(init_data) # Create the body of the initialization state @@ -558,6 +569,15 @@ def _connect_nested_sdfg_output_to_temporaries( return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) +def _remove_nested_sdfg_connector( + inner_ctx: gtir_to_sdfg.SubgraphContext, + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: gtir_to_sdfg_types.FieldopData, +) -> None: + inner_data.dc_node.desc(inner_ctx.sdfg).transient = True + nsdfg_node.out_connectors.pop(inner_data.dc_node.data) + + def translate_scan( node: gtir.Node, ctx: gtir_to_sdfg.SubgraphContext, @@ -596,7 +616,9 @@ def translate_scan( assert isinstance(stencil_expr, gtir.Lambda) # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension - scan_carry = str(stencil_expr.params[0].id) + scan_carry = stencil_expr.params[0].id + scan_carry_type = stencil_expr.params[0].type + assert isinstance(scan_carry_type, ts.DataType) # params[1]: boolean flag for forward/backward scan assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) @@ -606,19 +628,13 @@ def translate_scan( init_expr = scan_expr.args[2] # visit the initialization value of the scan expression init_data = sdfg_builder.visit(init_expr, ctx=ctx) - # extract type definition of the scan carry - 
scan_carry_type = ( - init_data.gt_type - if isinstance(init_data, gtir_to_sdfg_types.FieldopData) - else gtir_to_sdfg_types.get_tuple_type(init_data) - ) # define the set of symbols available in the lambda context, which consists of # the carry argument and all lambda function arguments - lambda_arg_types = [scan_carry_type] + [ + lambda_arg_types: list[ts.DataType] = [scan_carry_type] + [ arg.type for arg in node.args if isinstance(arg.type, ts.DataType) ] - lambda_symbols = { + lambda_symbols: dict[str, ts.DataType] = { str(p.id): arg_type for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True) } @@ -641,14 +657,11 @@ def translate_scan( lambda_args = [sdfg_builder.visit(arg, ctx=ctx) for arg in node.args] lambda_args_mapping = [ (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), - ] + [ - (im.sym(param.id, arg.gt_type), arg) - for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) - ] + ] + [(param, arg) for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True)] lambda_arg_nodes = dict( itertools.chain( - *[gtir_to_sdfg_types.flatten_tuples(psym.id, arg) for psym, arg in lambda_args_mapping] + *[gtir_to_sdfg_types.flatten_tuple(psym, arg) for psym, arg in lambda_args_mapping] ) ) @@ -677,24 +690,37 @@ def translate_scan( symbol_mapping=nsdfg_symbols_mapping, ) - lambda_input_edges = [] + input_edges = [] for input_connector, outer_arg in lambda_arg_nodes.items(): - arg_desc = outer_arg.dc_node.desc(ctx.sdfg) - input_subset = dace_subsets.Range.from_array(arg_desc) - input_edge = gtir_dataflow.MemletInputEdge( - ctx.state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector - ) - lambda_input_edges.append(input_edge) + assert not lambda_ctx.sdfg.arrays[input_connector].transient + if outer_arg is None: + # This argument has empty domain, which means that it should not be + # used inside the nested SDFG, and does not need to be connected outside. 
+ assert all( + node.data != input_connector + for node in lambda_ctx.sdfg.all_nodes_recursive() + if isinstance(node, dace.nodes.AccessNode) + ) + lambda_ctx.sdfg.arrays[input_connector].transient = True + else: + arg_desc = outer_arg.dc_node.desc(ctx.sdfg) + input_subset = dace_subsets.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + ctx.state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector + ) + input_edges.append(input_edge) # for output connections, we create temporary arrays that contain the computation # results of a column slice for each point in the horizontal domain - lambda_output_tree = gtx_utils.tree_map( - lambda lambda_output_data: _connect_nested_sdfg_output_to_temporaries( - lambda_ctx, ctx, nsdfg_node, lambda_output_data + output_tree = gtx_utils.tree_map( + lambda output_data, output_domain: _connect_nested_sdfg_output_to_temporaries( + lambda_ctx, ctx, nsdfg_node, output_data ) - )(lambda_output) + if output_domain != infer_domain.DomainAccessDescriptor.NEVER + else _remove_nested_sdfg_connector(lambda_ctx, nsdfg_node, output_data) + )(lambda_output, node.annex.domain) # we call a helper method to create a map scope that will compute the entire field return _create_scan_field_operator( - ctx, field_domain, node.type, sdfg_builder, lambda_input_edges, lambda_output_tree + ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, node.annex.domain ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py index 162133fb28..13dc5c4c1b 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py @@ -16,8 +16,10 @@ import dace from dace import subsets as dace_subsets +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront 
import fbuiltins as gtx_fbuiltins +from gt4py.next.iterator import ir as gtir from gt4py.next.program_processors.runners.dace import ( gtir_dataflow, gtir_domain, @@ -178,7 +180,7 @@ def get_symbol_mapping( return symbol_mapping -FieldopResult: TypeAlias = FieldopData | tuple[FieldopData | tuple, ...] +FieldopResult: TypeAlias = MaybeNestedInTuple[FieldopData | None] """Result of a field operator, can be either a field or a tuple fields.""" @@ -192,28 +194,19 @@ class SymbolicData: """Data type used for field indexing.""" -def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. - """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - -def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: +def flatten_tuple(sym: gtir.Sym, arg: FieldopResult) -> list[tuple[str, FieldopData | None]]: """ Visit a `FieldopResult`, potentially containing nested tuples, and construct a list of pairs `(str, FieldopData)` containing the symbol name of each tuple field and the corresponding `FieldopData`. 
""" if isinstance(arg, tuple): - tuple_type = get_tuple_type(arg) - tuple_symbols = gtir_to_sdfg_utils.flatten_tuple_fields(name, tuple_type) + assert isinstance(sym.type, ts.TupleType) + tuple_symbols = gtir_to_sdfg_utils.flatten_tuple_fields(sym.id, sym.type) tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) return [ (str(tsym.id), tfield) for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) ] else: - return [(name, arg)] + return [(sym.id, arg)] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py index 64509b2acd..c236d16c97 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py @@ -52,6 +52,8 @@ def get_arg_symbol_mapping( A mapping from inner symbol names to values or symbolic definitions in the parent SDFG. """ + if arg is None: + return {} if isinstance(arg, gtir_to_sdfg_types.FieldopData): return arg.get_symbol_mapping(dataname, sdfg) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index dcbff25dab..a9374c6a24 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -120,6 +120,7 @@ def test_compile_kwargs(cartesian_case, compile_testee): assert np.allclose(kwargs["out"].ndarray, a.ndarray + b.ndarray) +@pytest.mark.uses_scan def test_compile_scan(cartesian_case, compile_testee_scan): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 81ed0aee62..3fe845cf08 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -247,7 +247,10 @@ def test_gtir_tuple_swap(): body=[ gtir.SetAt( expr=im.make_tuple("y", "x"), - domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + domain=im.make_tuple( + im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + im.get_field_domain(gtx_common.GridType.CARTESIAN, "y", [IDim]), + ), # TODO(havogt): add a frontend check for this pattern target=im.make_tuple("x", "y"), ) @@ -455,10 +458,22 @@ def test_gtir_tuple_return(): body=[ gtir.SetAt( expr=im.make_tuple(im.make_tuple(im.op_as_fieldop("plus")("x", "y"), "x"), "y"), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, - im.tuple_get(0, im.tuple_get(0, "z")), - [IDim], + domain=im.make_tuple( + im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, + im.tuple_get(0, im.tuple_get(0, "z")), + [IDim], + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, + im.tuple_get(1, im.tuple_get(0, "z")), + [IDim], + ), + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), ), target=gtir.SymRef(id="z"), ) @@ -502,7 +517,10 @@ def test_gtir_tuple_target(): body=[ gtir.SetAt( expr=im.make_tuple(im.op_as_fieldop("plus")("x", 1.0), gtir.SymRef(id="x")), - domain=im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + domain=im.make_tuple( + im.get_field_domain(gtx_common.GridType.CARTESIAN, "x", [IDim]), + im.get_field_domain(gtx_common.GridType.CARTESIAN, "y", [IDim]), + ), target=im.make_tuple("x", "y"), ) ], @@ -1851,8 +1869,13 @@ def test_gtir_let_lambda_with_tuple1(): im.make_tuple(im.op_as_fieldop("plus", inner_domain)("x", "y"), "x"), "y" ), 
)(im.make_tuple(im.tuple_get(1, im.tuple_get(0, "t")), im.tuple_get(1, "t"))), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + domain=im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), ), target=gtir.SymRef(id="z"), ) @@ -1905,8 +1928,16 @@ def test_gtir_let_lambda_with_tuple2(): ) ) ), - domain=im.get_field_domain( - gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + domain=im.make_tuple( + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(0, "z"), [IDim] + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(1, "z"), [IDim] + ), + im.get_field_domain( + gtx_common.GridType.CARTESIAN, im.tuple_get(2, "z"), [IDim] + ), ), target=gtir.SymRef(id="z"), ) @@ -2248,7 +2279,7 @@ def test_gtir_scan(id, use_symbolic_column_size): im.make_tuple(0.0, True), ) )("x"), - domain=domain, + domain=im.make_tuple(domain, domain), target=im.make_tuple(gtir.SymRef(id="y"), gtir.SymRef(id="z")), ) ], From 8229f342dc28b717bb5ba2b9c65ae7937de793b3 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 30 Oct 2025 15:06:21 +0100 Subject: [PATCH 40/44] address review comments in SDFG lowering --- .../runners/dace/gtir_dataflow.py | 72 +++------ .../runners/dace/gtir_domain.py | 24 +-- .../runners/dace/gtir_to_sdfg.py | 130 ++++++--------- .../runners/dace/gtir_to_sdfg_primitives.py | 46 ++---- .../runners/dace/gtir_to_sdfg_scan.py | 148 ++++++++++-------- .../runners/dace/gtir_to_sdfg_types.py | 21 ++- .../runners/dace/gtir_to_sdfg_utils.py | 36 +---- tests/next_tests/definitions.py | 1 + .../ffront_tests/test_execution.py | 4 - .../iterator_tests/test_tuple.py | 8 +- .../iterator_tests/test_column_stencil.py | 3 +- 11 files changed, 205 insertions(+), 288 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py 
b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index 9d6327316f..9b114ffe39 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -29,6 +29,7 @@ from dace import subsets as dace_subsets from gt4py import eve +from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import builtins as gtir_builtins, ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @@ -371,7 +372,7 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: def get_tuple_type( - data: tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...], + data: NestedTuple[IteratorExpr | MemletExpr | ValueExpr], ) -> ts.TupleType: """ Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions. @@ -413,7 +414,7 @@ class LambdaToDataflow(eve.NodeVisitor): input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) symbol_map: dict[ str, - IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + MaybeNestedInTuple[IteratorExpr | DataExpr], ] = dataclasses.field(default_factory=dict) def _add_input_data_edge( @@ -747,10 +748,7 @@ def _visit_if_branch( expr: gtir.Expr, if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], direct_deref_iterators: Iterable[str], - ) -> tuple[ - list[DataflowInputEdge], - tuple[DataflowOutputEdge | tuple[Any, ...], ...], - ]: + ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. 
@@ -847,7 +845,7 @@ def _visit_if_branch_result( ) return ValueExpr(output_node, edge.result.gt_dtype) - def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + def _visit_if(self, node: gtir.FunCall) -> MaybeNestedInTuple[ValueExpr]: """ Lowers an if-expression with exclusive branch execution into a nested SDFG, in which each branch is lowered into a dataflow in a separate state and @@ -948,18 +946,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp for edge in in_edges: edge.connect(map_entry=None) - if isinstance(node.type, ts.TupleType): - out_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("__output", node.type) - outer_value = gtx_utils.tree_map( - lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) - )(output_tree, out_symbol_tree) - else: - assert isinstance(node.type, ts.FieldType | ts.ScalarType) - assert len(output_tree) == 1 and isinstance(output_tree[0], DataflowOutputEdge) - output_edge = output_tree[0] - outer_value = self._visit_if_branch_result( - nsdfg, nstate, output_edge, im.sym("__output", node.type) - ) + out_symbol = ( + gtir_to_sdfg_utils.make_symbol_tree("__output", node.type) + if isinstance(node.type, ts.TupleType) + else im.sym("__output", node.type) + ) + + outer_value = gtx_utils.tree_map( + lambda x, y, nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) + )(output_tree, out_symbol) + # Isolated access node will make validation fail. # Isolated access nodes can be found in `make_tuple` expressions that # construct tuples from input arguments. 
@@ -1803,9 +1799,7 @@ def _visit_tuple_get( tuple_fields = self.visit(node.args[1]) return tuple_fields[index] - def visit_FunCall( - self, node: gtir.FunCall - ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: + def visit_FunCall(self, node: gtir.FunCall) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -1843,9 +1837,7 @@ def visit_FunCall( else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda - ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: + def visit_Lambda(self, node: gtir.Lambda) -> MaybeNestedInTuple[DataflowOutputEdge]: def _visit_Lambda_impl( output_expr: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr, ) -> DataflowOutputEdge: @@ -1886,9 +1878,7 @@ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = gtx_dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef( - self, node: gtir.SymRef - ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: + def visit_SymRef(self, node: gtir.SymRef) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1899,13 +1889,8 @@ def visit_SymRef( def visit_let( self, node: gtir.Lambda, - args: Sequence[ - IteratorExpr - | MemletExpr - | ValueExpr - | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] - ], - ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: + args: Sequence[MaybeNestedInTuple[IteratorExpr | MemletExpr | ValueExpr]], + ) -> MaybeNestedInTuple[DataflowOutputEdge]: """ Maps lambda arguments to internal parameters. 
@@ -1940,16 +1925,8 @@ def translate_lambda_to_dataflow( state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[ - IteratorExpr - | MemletExpr - | ValueExpr - | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] - ], -) -> tuple[ - list[DataflowInputEdge], - tuple[DataflowOutputEdge | tuple[Any, ...], ...], -]: + args: Sequence[MaybeNestedInTuple[IteratorExpr | MemletExpr | ValueExpr]], +) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, that can be instantiated inside a map scope implementing the field operator. @@ -1973,7 +1950,4 @@ def translate_lambda_to_dataflow( taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) lambda_output = taskgen.visit_let(node, args) - if isinstance(lambda_output, DataflowOutputEdge): - return taskgen.input_edges, (lambda_output,) - else: - return taskgen.input_edges, lambda_output + return taskgen.input_edges, lambda_output diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py index 68bd1976b2..303bb5d936 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_domain.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_domain.py @@ -63,21 +63,25 @@ def get_field_domain(domain: domain_utils.SymbolicDomain) -> FieldopDomain: TargetDomain: TypeAlias = MaybeNestedInTuple[domain_utils.SymbolicDomain] """Symbolic domain which defines the range to write in the target field. -For tuple output, the corresponding domain in fieldview is a tuple of domains. +For tuple output in fieldview, `TargetDomain` is a tree-like tuple of symbolic domains. 
""" -class TargetDomainParser(eve.visitors.NodeTranslator): - """Visitor class to build a `TargetDomain` symbolic domain.""" +def extract_target_domain(node: gtir.Expr) -> TargetDomain: + """ + Visit a GTIR domain expression and construct a `TargetDomain` symbolic domain. + + We use a visitor class to extract the tree-like structure for (nested) tuple of domains. + """ - def visit_FunCall(self, node: gtir.FunCall) -> TargetDomain: - if cpm.is_call_to(node, "make_tuple"): - return tuple(self.visit(arg) for arg in node.args) - else: - return domain_utils.SymbolicDomain.from_expr(node) + class TargetDomainParser(eve.visitors.NodeTranslator): + def visit_FunCall(self, node: gtir.FunCall) -> TargetDomain: + if cpm.is_call_to(node, "make_tuple"): + return tuple(self.visit(arg) for arg in node.args) + else: + return domain_utils.SymbolicDomain.from_expr(node) - def apply(cls, node: gtir.Expr) -> TargetDomain: - return cls.visit(node) + return TargetDomainParser().visit(node) def get_domain_indices( diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index 5ebe84e337..9bebfc5593 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -16,7 +16,6 @@ import abc import dataclasses -import itertools from typing import ( Any, Dict, @@ -683,7 +682,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: def visit_SetAt( self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dace.SDFGState | None: + ) -> dace.SDFGState: """Visits a `SetAt` statement expression and writes the local result to some external storage. Each statement expression results in some sort of dataflow gragh writing to temporary storage. @@ -694,7 +693,7 @@ def visit_SetAt( """ # Visit the domain expression. 
- domain = gtir_domain.TargetDomainParser().apply(stmt.domain) + domain = gtir_domain.extract_target_domain(stmt.domain) # Visit the field operator expression. source_tree = self._visit_expression(stmt.expr, sdfg, state) @@ -753,24 +752,11 @@ def _visit_target( ), ) - if isinstance(target_tree, tuple) and not isinstance(domain, tuple): - # This branch handles a specific case that indeed never happens in - # fieldview GTIR, only in iterator GTIR tests. The case corresponds - # to 'as_fieldop' with tuple output and single domain, which is a format - # used when multiple 'as_fieldop' are fused into one. The input to SDFG - # lowering is fieldview IR, where 'as_fieldop' will always have a single - # domain and the frontend will never emit 'as_fieldop' with tuple output. - gtx_utils.tree_map( - lambda source, target, domain_=domain, target_state_=target_state: _visit_target( - source, target, domain_, target_state_ - ) - )(source_tree, target_tree) - else: - gtx_utils.tree_map( - lambda source, target, domain_, target_state_=target_state: _visit_target( - source, target, domain_, target_state_ - ) - )(source_tree, target_tree, domain) + gtx_utils.tree_map( + lambda source, target, domain_, target_state_=target_state: _visit_target( + source, target, domain_, target_state_ + ) + )(source_tree, target_tree, domain) if target_state.is_empty(): sdfg.remove_node(target_state) @@ -858,26 +844,22 @@ def visit_Lambda( the previous symbol during traversal of the lambda expression. 
""" - symbolic_args = { - str(param.id): arg - for param, arg in args.items() - if isinstance(arg, gtir_to_sdfg_types.SymbolicData) - } - data_args: dict[str, gtir_to_sdfg_types.FieldopResult] = { - str(param.id): arg # type: ignore[misc] # symbolic args are filtered out - for param, arg in args.items() - if arg is not None and param.id not in symbolic_args - } - - lambda_arg_nodes = dict( - itertools.chain( - *[ - gtir_to_sdfg_types.flatten_tuple(param, arg) # type: ignore[arg-type] # symbolic args are filtered out - for param, arg in args.items() - if param.id in data_args - ] - ) - ) + data_args: dict[str, gtir_to_sdfg_types.FieldopResult] = {} + symbolic_args: dict[str, gtir_to_sdfg_types.SymbolicData] = {} + lambda_arg_nodes: dict[str, gtir_to_sdfg_types.FieldopData] = {} + for param, arg in args.items(): + pname = str(param.id) + if arg is None: + pass # domain inference has detetcted that this argument is not used + elif isinstance(arg, gtir_to_sdfg_types.SymbolicData): + symbolic_args[pname] = arg + else: + data_args[pname] = arg + lambda_arg_nodes |= { + str(nested_param.id): nested_arg + for nested_param, nested_arg in gtir_to_sdfg_types.flatten_tuple(param, arg) + if nested_arg is not None # we filter out arguments with empty domain + } # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { @@ -908,33 +890,28 @@ def visit_Lambda( } input_memlets = {} + unused_data = set() for nsdfg_dataname, nsdfg_datadesc in lambda_ctx.sdfg.arrays.items(): if nsdfg_datadesc.transient: - continue + pass # nothing to do here elif nsdfg_dataname in lambda_arg_nodes: arg_node = lambda_arg_nodes[nsdfg_dataname] - if arg_node is None: - # This argument has empty domain, which means that it should not be - # used inside the nested SDFG, and does not need to be connected outside. 
- assert all( - node.data != nsdfg_dataname - for node in lambda_ctx.sdfg.all_nodes_recursive() - if isinstance(node, dace.nodes.AccessNode) - ) - lambda_ctx.sdfg.arrays[nsdfg_dataname].transient = True - continue - else: - dataname = arg_node.dc_node.data - datadesc = arg_node.dc_node.desc(ctx.sdfg) + source_data = arg_node.dc_node.data + input_memlets[nsdfg_dataname] = ctx.sdfg.make_array_memlet(source_data) + elif nsdfg_dataname in ctx.sdfg.arrays: + source_data = nsdfg_dataname + # ensure that connectivity tables are non-transient arrays in parent SDFG + if source_data in connectivity_arrays: + ctx.sdfg.arrays[source_data].transient = False + input_memlets[nsdfg_dataname] = ctx.sdfg.make_array_memlet(source_data) else: - dataname = nsdfg_dataname - datadesc = ctx.sdfg.arrays[nsdfg_dataname] - - # ensure that connectivity tables are non-transient arrays in parent SDFG - if dataname in connectivity_arrays: - datadesc.transient = False + # This argument has empty domain, which means that it is not used + # by the lambda expression, and does not need to be connected on + # the nested SDFG. + unused_data.add(nsdfg_dataname) - input_memlets[nsdfg_dataname] = ctx.sdfg.make_array_memlet(dataname) + for data in sorted(unused_data): # NOTE: remove the data in deterministic order + lambda_ctx.sdfg.remove_data(data, validate=__debug__) # Process lambda outputs # @@ -965,19 +942,18 @@ def visit_Lambda( # Map free symbols to parent SDFG nsdfg_symbols_mapping = {} - for sym in lambda_ctx.sdfg.free_symbols: - if (sym_id := str(sym)) in lambda_arg_nodes: - arg_node = lambda_arg_nodes[sym_id] - assert arg_node and isinstance(arg_node.gt_type, ts.ScalarType) + for symbol in lambda_ctx.sdfg.free_symbols: + if symbol in lambda_arg_nodes: + assert isinstance(lambda_arg_nodes[symbol].gt_type, ts.ScalarType) raise NotImplementedError( "Unexpected mapping of scalar node to symbol on nested SDFG." 
) - elif sym_id in symbolic_args: - nsdfg_symbols_mapping[sym_id] = symbolic_args[sym_id].value + elif symbol in symbolic_args: + nsdfg_symbols_mapping[symbol] = symbolic_args[symbol].value else: - nsdfg_symbols_mapping[sym_id] = sym - for param, arg in data_args.items(): - nsdfg_symbols_mapping |= gtir_to_sdfg_utils.get_arg_symbol_mapping(param, arg, ctx.sdfg) + nsdfg_symbols_mapping[symbol] = symbol + for pname, arg in lambda_arg_nodes.items(): + nsdfg_symbols_mapping |= arg.get_symbol_mapping(pname, ctx.sdfg) nsdfg_node = ctx.state.add_nested_sdfg( lambda_ctx.sdfg, @@ -990,17 +966,7 @@ def visit_Lambda( for input_connector, memlet in input_memlets.items(): if input_connector in lambda_arg_nodes: arg_node = lambda_arg_nodes[input_connector] - if arg_node is None: - # this argument has empty domain, therefore it should not be used inside the nested SDFG - assert all( - node.data != input_connector - for node in lambda_ctx.sdfg.all_nodes_recursive() - if isinstance(node, dace.nodes.AccessNode) - ) - lambda_ctx.sdfg.arrays[input_connector].transient = True - continue - else: - src_node = arg_node.dc_node + src_node = arg_node.dc_node else: src_node = ctx.state.add_access(memlet.data) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index 78fab48a16..1f4298c591 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -9,7 +9,7 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol +from typing import TYPE_CHECKING, Iterable, Optional, Protocol import dace from dace import subsets as dace_subsets @@ -175,17 +175,13 @@ def _create_field_operator_impl( def _create_field_operator( ctx: gtir_to_sdfg.SubgraphContext, domain: gtir_domain.FieldopDomain, - node_type: ts.FieldType | 
ts.TupleType, + node_type: ts.FieldType, sdfg_builder: gtir_to_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_tree: tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> gtir_to_sdfg_types.FieldopResult: """ - Helper method to build the output of a field operator, which can consist of - a single field or a tuple of fields. - - A tuple of fields is returned when one stencil computes a grid point on multiple - fields: for each field, this method will call `_create_field_operator_impl()`. + Helper method to build the output of a field operator. Args: ctx: The SDFG context in which to lower the field operator. @@ -193,11 +189,11 @@ def _create_field_operator( node_type: The GT4Py type of the IR node that produces this field. sdfg_builder: The object used to build the map scope in the provided SDFG. input_edges: List of edges to pass input data into the dataflow. - output_tree: A tree representation of the dataflow output data. + output_edge: Edge corresponding to the dataflow output. Returns: - The descriptor of the field operator result, which can be either a single - field or a tuple fields. + The descriptor of the field operator result, which is a single field defined + on the domain of the field operator. 
""" if len(domain) == 0: @@ -219,22 +215,7 @@ def _create_field_operator( for edge in input_edges: edge.connect(map_entry) - if isinstance(node_type, ts.FieldType): - assert len(output_tree) == 1 and isinstance( - output_tree[0], gtir_dataflow.DataflowOutputEdge - ) - output_edge = output_tree[0] - return _create_field_operator_impl( - ctx, sdfg_builder, domain, output_edge, node_type, map_exit - ) - else: - # handle tuples of fields - output_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("x", node_type) - return gtx_utils.tree_map( - lambda edge, sym, ctx_=ctx: _create_field_operator_impl( - ctx_, sdfg_builder, domain, edge, sym.type, map_exit - ) - )(output_tree, output_symbol_tree) + return _create_field_operator_impl(ctx, sdfg_builder, domain, output_edge, node_type, map_exit) def translate_as_fieldop( @@ -266,11 +247,13 @@ def translate_as_fieldop( if cpm.is_call_to(fieldop_expr, "scan"): return translate_scan(node, ctx, sdfg_builder) + if not isinstance(node.type, ts.FieldType): + raise NotImplementedError("Unexpected 'as_filedop' with tuple output in SDFG lowering.") + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. 
- assert isinstance(node.type, ts.FieldType) stencil_expr = im.lambda_("a")(im.deref("a")) stencil_expr.expr.type = node.type.dtype elif isinstance(fieldop_expr, gtir.Lambda): @@ -292,12 +275,13 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, ctx, sdfg_builder, field_domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( + input_edges, output_edge = gtir_dataflow.translate_lambda_to_dataflow( ctx.sdfg, ctx.state, sdfg_builder, stencil_expr, fieldop_args ) + assert isinstance(output_edge, gtir_dataflow.DataflowOutputEdge) return _create_field_operator( - ctx, field_domain, node.type, sdfg_builder, input_edges, output_edges + ctx, field_domain, node.type, sdfg_builder, input_edges, output_edge ) @@ -511,7 +495,7 @@ def translate_index( ] output_edge = gtir_dataflow.DataflowOutputEdge(ctx.state, index_value) return _create_field_operator( - ctx, field_domain, node.type, sdfg_builder, input_edges, (output_edge,) + ctx, field_domain, node.type, sdfg_builder, input_edges, output_edge ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index 3550bd140a..a429c0820b 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -22,8 +22,7 @@ from __future__ import annotations -import itertools -from typing import Any, Iterable +from typing import Iterable import dace from dace import subsets as dace_subsets @@ -53,7 +52,7 @@ def _parse_scan_fieldop_arg( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, field_domain: gtir_domain.FieldopDomain, -) -> gtir_dataflow.MemletExpr | tuple[gtir_dataflow.MemletExpr | tuple[Any, ...], ...]: +) -> MaybeNestedInTuple[gtir_dataflow.MemletExpr]: """Helper 
method to visit an expression passed as argument to a scan field operator. On the innermost level, a scan operator is lowered to a loop region which computes @@ -107,8 +106,17 @@ def _create_scan_field_operator_impl( Therefore, the memlet subset will write a slice into the result array, that corresponds to the full vertical shape for each horizontal grid point. + Another difference is that this function is called on all fields inside a tuple, + in case of tuple return. Note that a regular field operator only computes a + single field, never a tuple of fields. For tuples, it can happen that one of + the nested fields is not used, outside the scan field operator, and therefore + does not need to be computed. Then, the domain inferred by gt4py on this field + is empty and the corresponding `output_edge` argument to this function is None. + In this case, the function does not allocate an array node for the output field + and returns None. + Refer to `gtir_to_sdfg_primitives._create_field_operator_impl()` for - the description of function arguments and return values. + the description of function arguments. """ if output_edge is None: # According to domain inference, this tuple field does not need to be computed. @@ -223,7 +231,12 @@ def _create_scan_field_operator( by a loop region in a mapped nested SDFG. Refer to `gtir_to_sdfg_primitives._create_field_operator()` for the - description of function arguments and return values. + description of function arguments. Note that the return value is different, + because the scan field operator can return a tuple of fields, while a regular + field operator return a single field. The domain of the nested fields, in + a tuple, can be empty, in case the nested field is not used outside the scan. + In this case, the corresponding `output` edge will be None and this function + will also return None for the corresponding field inside the tree-like result. 
""" dims, _, _ = gtir_domain.get_field_layout(field_domain) @@ -254,29 +267,29 @@ def _create_scan_field_operator( for edge in input_edges: edge.connect(map_entry) + output_symbol: MaybeNestedInTuple[gtir.Sym] if isinstance(node_type, ts.FieldType): - assert isinstance(output, gtir_dataflow.DataflowOutputEdge) - assert isinstance(output_domain, domain_utils.SymbolicDomain) - return _create_scan_field_operator_impl( - ctx, sdfg_builder, output, output_domain, node_type, map_exit - ) + assert isinstance(output_domain, infer_domain.NonTupleDomainAccess) + output_symbol = im.sym("__gtir_guaranteed_unused_dummy_variable", node_type) else: - # Handle tuples of fields. note that the symbol name 'x' in the call below - # is not used, we only need the tree structure of the `TupleType` definition - # to pass to `tree_map()` in order to retrieve the type of each field. - output_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("x", node_type) - return gtx_utils.tree_map( - lambda edge_, domain_, sym_, ctx_=ctx: ( - _create_scan_field_operator_impl( - ctx_, - sdfg_builder, - edge_, - domain_, - sym_.type, - map_exit, - ) + # handle tuples of fields + assert isinstance(output_domain, tuple) + output_symbol = gtir_to_sdfg_utils.make_symbol_tree( + "__gtir_guaranteed_unused_dummy_variable", node_type + ) + + return gtx_utils.tree_map( + lambda edge, domain, sym, ctx_=ctx: ( + _create_scan_field_operator_impl( + ctx_, + sdfg_builder, + edge, + domain, + sym.type, + map_exit, ) - )(output, output_domain, output_symbol_tree) + ) + )(output, output_domain, output_symbol) def _scan_input_name(input_name: str) -> str: @@ -508,15 +521,9 @@ def connect_scan_output( # write the stencil result (value on one vertical level) into a 1D field # with full vertical shape representing one column - if isinstance(scan_carry_input, tuple): - assert isinstance(lambda_result_shape, tuple) - lambda_output = gtx_utils.tree_map(connect_scan_output)( - lambda_result, lambda_result_shape, scan_carry_input 
- ) - else: - assert isinstance(lambda_result[0], gtir_dataflow.DataflowOutputEdge) - assert isinstance(lambda_result_shape, list) - lambda_output = connect_scan_output(lambda_result[0], lambda_result_shape, scan_carry_input) + lambda_output = gtx_utils.tree_map(connect_scan_output)( + lambda_result, lambda_result_shape, scan_carry_input + ) # in case tuples are passed as argument, isolated access nodes might be left in the state, # because not all tuple fields are necessarily accessed inside the lambda scope @@ -569,13 +576,22 @@ def _connect_nested_sdfg_output_to_temporaries( return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) -def _remove_nested_sdfg_connector( - inner_ctx: gtir_to_sdfg.SubgraphContext, +def _handle_dataflow_result_of_nested_sdfg( nsdfg_node: dace.nodes.NestedSDFG, + inner_ctx: gtir_to_sdfg.SubgraphContext, + outer_ctx: gtir_to_sdfg.SubgraphContext, inner_data: gtir_to_sdfg_types.FieldopData, -) -> None: - inner_data.dc_node.desc(inner_ctx.sdfg).transient = True - nsdfg_node.out_connectors.pop(inner_data.dc_node.data) + field_domain: infer_domain.NonTupleDomainAccess, +) -> gtir_dataflow.DataflowOutputEdge | None: + if isinstance(field_domain, domain_utils.SymbolicDomain): + return _connect_nested_sdfg_output_to_temporaries( + inner_ctx, outer_ctx, nsdfg_node, inner_data + ) + else: + assert field_domain == infer_domain.DomainAccessDescriptor.NEVER + inner_data.dc_node.desc(inner_ctx.sdfg).transient = True + nsdfg_node.out_connectors.pop(inner_data.dc_node.data) + return None def translate_scan( @@ -657,13 +673,19 @@ def translate_scan( lambda_args = [sdfg_builder.visit(arg, ctx=ctx) for arg in node.args] lambda_args_mapping = [ (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), - ] + [(param, arg) for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True)] + ] + [ + (param, arg) + for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + if arg is not None + ] - lambda_arg_nodes 
= dict( - itertools.chain( - *[gtir_to_sdfg_types.flatten_tuple(psym, arg) for psym, arg in lambda_args_mapping] - ) - ) + lambda_arg_nodes: dict[str, gtir_to_sdfg_types.FieldopData] = {} + for param, arg in lambda_args_mapping: + lambda_arg_nodes |= { + str(nested_param.id): nested_arg + for nested_param, nested_arg in gtir_to_sdfg_types.flatten_tuple(param, arg) + if nested_arg is not None + } # parse the dataflow output symbols if isinstance(scan_carry_type, ts.TupleType): @@ -678,8 +700,8 @@ def translate_scan( # build the mapping of symbols from nested SDFG to field operator context nsdfg_symbols_mapping = {str(sym): sym for sym in lambda_ctx.sdfg.free_symbols} - for psym, arg in lambda_args_mapping: - nsdfg_symbols_mapping |= gtir_to_sdfg_utils.get_arg_symbol_mapping(psym.id, arg, ctx.sdfg) + for pname, arg in lambda_arg_nodes.items(): + nsdfg_symbols_mapping |= arg.get_symbol_mapping(pname, ctx.sdfg) # the scan nested SDFG is ready: it is instantiated in the field operator context # where the map scope over the horizontal domain lives @@ -693,31 +715,23 @@ def translate_scan( input_edges = [] for input_connector, outer_arg in lambda_arg_nodes.items(): assert not lambda_ctx.sdfg.arrays[input_connector].transient - if outer_arg is None: - # This argument has empty domain, which means that it should not be - # used inside the nested SDFG, and does not need to be connected outside. 
- assert all( - node.data != input_connector - for node in lambda_ctx.sdfg.all_nodes_recursive() - if isinstance(node, dace.nodes.AccessNode) - ) - lambda_ctx.sdfg.arrays[input_connector].transient = True - else: - arg_desc = outer_arg.dc_node.desc(ctx.sdfg) - input_subset = dace_subsets.Range.from_array(arg_desc) - input_edge = gtir_dataflow.MemletInputEdge( - ctx.state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector - ) - input_edges.append(input_edge) + arg_desc = outer_arg.dc_node.desc(ctx.sdfg) + input_subset = dace_subsets.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + ctx.state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector + ) + input_edges.append(input_edge) # for output connections, we create temporary arrays that contain the computation # results of a column slice for each point in the horizontal domain output_tree = gtx_utils.tree_map( - lambda output_data, output_domain: _connect_nested_sdfg_output_to_temporaries( - lambda_ctx, ctx, nsdfg_node, output_data + lambda output_data, output_domain: _handle_dataflow_result_of_nested_sdfg( + nsdfg_node=nsdfg_node, + inner_ctx=lambda_ctx, + outer_ctx=ctx, + inner_data=output_data, + field_domain=output_domain, ) - if output_domain != infer_domain.DomainAccessDescriptor.NEVER - else _remove_nested_sdfg_connector(lambda_ctx, nsdfg_node, output_data) )(lambda_output, node.annex.domain) # we call a helper method to create a map scope that will compute the entire field diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py index 13dc5c4c1b..53a2a2b0ca 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py @@ -181,7 +181,13 @@ def get_symbol_mapping( FieldopResult: TypeAlias = MaybeNestedInTuple[FieldopData | None] -"""Result of a field operator, can be either a 
field or a tuple fields.""" +"""Result of a field operator, can be either a field or a tuple fields. + +For tuple of fields, any of the nested fields can be None, in case it is not udes +and therefore does not need to be computed. The information whether a field needs +to be computed or not is the result of GTIR domain inference, and it is stored in +the GTIR node annex domain. +""" @dataclasses.dataclass(frozen=True) @@ -194,19 +200,18 @@ class SymbolicData: """Data type used for field indexing.""" -def flatten_tuple(sym: gtir.Sym, arg: FieldopResult) -> list[tuple[str, FieldopData | None]]: +def flatten_tuple(sym: gtir.Sym, arg: FieldopResult) -> list[tuple[gtir.Sym, FieldopData | None]]: """ - Visit a `FieldopResult`, potentially containing nested tuples, and construct a list - of pairs `(str, FieldopData)` containing the symbol name of each tuple field and - the corresponding `FieldopData`. + Visit a `FieldopResult`, potentially containing nested tuples, and construct + a list of pairs `(gtir.Sym, FieldopData)` containing the symbol of each tuple + field and the corresponding `FieldopData`. 
""" if isinstance(arg, tuple): assert isinstance(sym.type, ts.TupleType) tuple_symbols = gtir_to_sdfg_utils.flatten_tuple_fields(sym.id, sym.type) tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) return [ - (str(tsym.id), tfield) - for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) + (tsym, tfield) for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) ] else: - return [(sym.id, arg)] + return [(sym, arg)] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py index c236d16c97..f99f2b4660 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_utils.py @@ -13,10 +13,11 @@ import dace from gt4py import eve +from gt4py.eve.extended_typing import NestedTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.runners.dace import gtir_python_codegen, gtir_to_sdfg_types +from gt4py.next.program_processors.runners.dace import gtir_python_codegen from gt4py.next.type_system import type_specifications as ts @@ -36,35 +37,6 @@ def debug_info( return default -def get_arg_symbol_mapping( - dataname: str, arg: gtir_to_sdfg_types.FieldopResult, sdfg: dace.SDFG -) -> dict[str, dace.symbolic.SymExpr]: - """ - Helper method to build the mapping from inner to outer SDFG of all symbols - used for storage of a field or a tuple of fields. - - Args: - dataname: The storage name inside the nested SDFG. - arg: The argument field in the parent SDFG. - sdfg: The parent SDFG where the argument field lives. - - Returns: - A mapping from inner symbol names to values or symbolic definitions - in the parent SDFG. 
- """ - if arg is None: - return {} - if isinstance(arg, gtir_to_sdfg_types.FieldopData): - return arg.get_symbol_mapping(dataname, sdfg) - - symbol_mapping: dict[str, dace.symbolic.SymExpr] = {} - for i, elem in enumerate(arg): - dataname_elem = f"{dataname}_{i}" - symbol_mapping |= get_arg_symbol_mapping(dataname_elem, elem, sdfg) - - return symbol_mapping - - def get_map_variable(dim: gtx_common.Dimension) -> str: """ Format map variable name based on the naming convention for application-specific SDFG transformations. @@ -73,7 +45,7 @@ def get_map_variable(dim: gtx_common.Dimension) -> str: return f"i_{dim.value}_gtx_{dim.kind}{suffix}" -def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: +def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> NestedTuple[gtir.Sym]: """ Creates a tree representation of the symbols corresponding to the tuple fields. The constructed tree preserves the nested nature of the tuple type, if any. @@ -91,7 +63,7 @@ def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sy assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] return tuple( - make_symbol_tree(field_name, field_type) # type: ignore[misc] + make_symbol_tree(field_name, field_type) if isinstance(field_type, ts.TupleType) else im.sym(field_name, field_type) for field_name, field_type in fields diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 475b41a689..0bff0b0aa7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -159,6 +159,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] ) EMBEDDED_SKIP_LIST 
= [ diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index fc4579a178..647a94c8a2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -421,7 +421,6 @@ def testee(qc: cases.IKFloatField, scalar: float): @pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator -@pytest.mark.uses_tuple_iterator def test_tuple_scalar_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( @@ -974,7 +973,6 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): @pytest.mark.uses_scan @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -1003,7 +1001,6 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: @pytest.mark.uses_scan -@pytest.mark.uses_tuple_iterator def test_scan_different_domain_in_tuple(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] @@ -1043,7 +1040,6 @@ def foo( @pytest.mark.uses_scan -@pytest.mark.uses_tuple_iterator def test_scan_tuple_field_scalar_mixed(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index ea89bb23ba..13bf25f50d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -40,6 +40,7 @@ def tuple_output2(inp1, inp2): @pytest.mark.parametrize("stencil", [tuple_output1, tuple_output2]) +@pytest.mark.uses_tuple_iterator @pytest.mark.uses_tuple_returns def 
test_tuple_output(program_processor, stencil): program_processor, validate = program_processor @@ -71,6 +72,7 @@ def tuple_of_tuple_output2(inp1, inp2, inp3, inp4): return make_tuple(deref(inp1), deref(inp2)), make_tuple(deref(inp3), deref(inp4)) +@pytest.mark.uses_tuple_iterator @pytest.mark.uses_tuple_returns def test_tuple_of_tuple_of_field_output(program_processor): program_processor, validate = program_processor @@ -109,6 +111,7 @@ def stencil(inp1, inp2, inp3, inp4): @pytest.mark.parametrize("stencil", [tuple_output1, tuple_output2]) +@pytest.mark.uses_tuple_iterator def test_tuple_of_field_output_constructed_inside(program_processor, stencil): program_processor, validate = program_processor @@ -144,6 +147,7 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): assert np.allclose(inp2.asnumpy(), out2.asnumpy()) +@pytest.mark.uses_tuple_iterator def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): program_processor, validate = program_processor @@ -219,7 +223,6 @@ def tuple_input(inp): @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -273,7 +276,6 @@ def tuple_tuple_input(inp): @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor @@ -321,7 +323,6 @@ def test_field_of_2_extra_dim_input(program_processor): @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_scalar_tuple_args(program_processor): @fundef def stencil(inp): @@ -351,7 +352,6 @@ def stencil(inp): @pytest.mark.uses_tuple_args -@pytest.mark.uses_tuple_iterator def test_mixed_field_scalar_tuple_arg(program_processor): @fundef def stencil(inp): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py 
b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 22b9f5f9e8..24cd3426e8 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -142,7 +142,6 @@ def k_level_condition_upper_tuple(k_idx, k_level): gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), - marks=pytest.mark.uses_tuple_iterator, ), ], ) @@ -257,6 +256,7 @@ def ksum_even_odd_fencil(i_size, k_size, inp, out): @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_ksum_even_odd_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] @@ -305,6 +305,7 @@ def ksum_even_odd_nested_fencil(i_size, k_size, inp, out): @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_ksum_even_odd_nested_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] From 9d890935229e22d16cae4e724fd9c2ad50803808 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 31 Oct 2025 09:44:00 +0100 Subject: [PATCH 41/44] edit comment --- .../program_processors/runners/dace/gtir_to_sdfg_scan.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index 4e25cd3104..11213e2d02 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -581,10 +581,15 @@ def _handle_dataflow_result_of_nested_sdfg( field_domain: infer_domain.NonTupleDomainAccess, ) -> gtir_dataflow.DataflowOutputEdge | None: if isinstance(field_domain, domain_utils.SymbolicDomain): + # The field is used outside the nested SDFG, therefore it needs to be copied + # to a temporary array in the 
parent SDFG (outer context). return _connect_nested_sdfg_output_to_temporaries( inner_ctx, outer_ctx, nsdfg_node, inner_data ) else: + # The field is not used outside the nested SDFG. It is likely just storage + # for some internal state, accessed during column scan, and can be turned + # into a transient array inside the nested SDFG. assert field_domain == infer_domain.DomainAccessDescriptor.NEVER inner_data.dc_node.desc(inner_ctx.sdfg).transient = True nsdfg_node.out_connectors.pop(inner_data.dc_node.data) From 6eddd52d2639820d9ffce8ed4d00f51d43a59bfe Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 31 Oct 2025 12:02:57 +0100 Subject: [PATCH 42/44] apply review comments --- .../runners/dace/gtir_to_sdfg.py | 4 +-- .../runners/dace/gtir_to_sdfg_primitives.py | 28 ++++++++++++------- .../runners/dace/gtir_to_sdfg_scan.py | 10 +++---- .../runners/dace/gtir_to_sdfg_types.py | 2 +- 4 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py index 964e07255c..88165dc2ee 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg.py @@ -753,8 +753,8 @@ def _visit_target( ) gtx_utils.tree_map( - lambda source, target, domain_, target_state_=target_state: _visit_target( - source, target, domain_, target_state_ + lambda source, target, target_domain: _visit_target( + source, target, target_domain, target_state ) )(source_tree, target_tree, domain) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py index 1f4298c591..a3072187f7 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_primitives.py @@ -426,12 +426,8 @@ def translate_if( false_br_result = 
sdfg_builder.visit(false_expr, ctx=fbranch_ctx) node_output = gtx_utils.tree_map( - lambda domain, - true_br, - false_br, - _ctx=ctx, - sdfg_builder=sdfg_builder: _construct_if_branch_output( - ctx=_ctx, + lambda domain, true_br, false_br: _construct_if_branch_output( + ctx=ctx, sdfg_builder=sdfg_builder, field_domain=gtir_domain.get_field_domain(domain), true_br=true_br, @@ -442,10 +438,10 @@ def translate_if( true_br_result, false_br_result, ) - gtx_utils.tree_map(lambda src, dst, _ctx=tbranch_ctx: _write_if_branch_output(_ctx, src, dst))( + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(tbranch_ctx, src, dst))( true_br_result, node_output ) - gtx_utils.tree_map(lambda src, dst, _ctx=fbranch_ctx: _write_if_branch_output(_ctx, src, dst))( + gtx_utils.tree_map(lambda src, dst: _write_if_branch_output(fbranch_ctx, src, dst))( false_br_result, node_output ) @@ -600,9 +596,21 @@ def translate_tuple_get( data_nodes = sdfg_builder.visit(node.args[1], ctx=ctx) if isinstance(data_nodes, gtir_to_sdfg_types.FieldopData): raise ValueError(f"Invalid tuple expression {node}") - unused_arg_nodes: Iterable[gtir_to_sdfg_types.FieldopData] = gtx_utils.flatten_nested_tuple( - tuple(arg for i, arg in enumerate(data_nodes) if arg is not None and i != index) + # Now we remove the tuple fields that are not used, to avoid an SDFG validation + # error because of isolated access nodes. + unused_arg_nodes = gtx_utils.flatten_nested_tuple( + tuple(arg for i, arg in enumerate(data_nodes) if i != index) ) + # However, for temporary fields inside the tuple (non-globals and non-scalar + # values, supposed to contain the result of some field operator) the gt4py + # domain inference should have already set an empty domain, so the corresponding + # `arg` is expected to be None and can be ignored. 
+ assert all( + not arg.dc_node.desc(ctx.sdfg).transient or isinstance(arg.gt_type, ts.ScalarType) + for arg in unused_arg_nodes + if arg is not None + ) + unused_arg_nodes = tuple(arg for arg in unused_arg_nodes if arg is not None) ctx.state.remove_nodes_from( [arg.dc_node for arg in unused_arg_nodes if ctx.state.degree(arg.dc_node) == 0] ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py index 11213e2d02..575bd90d56 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_scan.py @@ -83,7 +83,7 @@ def _parse_fieldop_arg_impl( return _parse_fieldop_arg_impl(arg) else: # handle tuples of fields - return gtx_utils.tree_map(lambda x: _parse_fieldop_arg_impl(x))(arg) + return gtx_utils.tree_map(_parse_fieldop_arg_impl)(arg) def _create_scan_field_operator_impl( @@ -276,9 +276,9 @@ def _create_scan_field_operator( ) return gtx_utils.tree_map( - lambda edge, domain, sym, ctx_=ctx: ( + lambda edge, domain, sym: ( _create_scan_field_operator_impl( - ctx_, + ctx, sdfg_builder, edge, domain, @@ -676,8 +676,8 @@ def translate_scan( lambda_args_mapping = [ (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), ] + [ - (param, arg) - for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + (gt_symbol, arg) + for gt_symbol, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) if arg is not None ] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py index 53a2a2b0ca..c2f5fbd081 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_types.py @@ -183,7 +183,7 @@ def get_symbol_mapping( FieldopResult: TypeAlias = MaybeNestedInTuple[FieldopData | None] """Result of 
a field operator, can be either a field or a tuple fields. -For tuple of fields, any of the nested fields can be None, in case it is not udes +For tuple of fields, any of the nested fields can be None, in case it is not used and therefore does not need to be computed. The information whether a field needs to be computed or not is the result of GTIR domain inference, and it is stored in the GTIR node annex domain. From 40dccd7751a2df6f3ad064403854502bc1cf4bca Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 3 Nov 2025 09:17:09 +0100 Subject: [PATCH 43/44] remove normalize_domain --- src/gt4py/next/common.py | 11 ----------- src/gt4py/next/embedded/operators.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/iterator/embedded.py | 6 +++--- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ff2a0bb31b..7698419bbe 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -30,7 +30,6 @@ Final, Generic, Literal, - MaybeNestedInTuple, NamedTuple, Never, Optional, @@ -590,16 +589,6 @@ def __getstate__(self) -> dict[str, Any]: ) # `Domain` is `Sequence[NamedRange]` and therefore a subset -def normalize_domains(domain_like: MaybeNestedInTuple[DomainLike]) -> MaybeNestedInTuple[Domain]: - """ - Convert a potentially nested tuple structure of `DomainLike` objects to `Domain` objects. - """ - if isinstance(domain_like, tuple): - return tuple(normalize_domains(item) for item in domain_like) - else: - return domain(domain_like) - - def domain(domain_like: DomainLike) -> Domain: """ Construct `Domain` from `DomainLike` object. 
diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 9398111e5c..116772cc9b 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -110,7 +110,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) -> domain = kwargs.pop("domain", None) out_domain = ( - common.normalize_domains(domain) if domain is not None else _get_out_domain(out) + utils.tree_map(common.domain)(domain) if domain is not None else _get_out_domain(out) ) new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bd1f06a9f2..308097692c 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -714,7 +714,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") if "domain" in kwargs: - domain = common.normalize_domains(kwargs.pop("domain")) + domain = utils.tree_map(common.domain)(kwargs.pop("domain")) if not isinstance(domain, tuple): domain = utils.tree_map(lambda _: domain)(out) out = utils.tree_map(lambda f, dom: f[dom])(out, domain) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b6c663ac42..db3b98fed6 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -45,7 +45,7 @@ overload, runtime_checkable, ) -from gt4py.next import common, field_utils +from gt4py.next import common, field_utils, utils from gt4py.next.embedded import ( context as embedded_context, exceptions as embedded_exceptions, @@ -1634,8 +1634,8 @@ def set_at( domain: xtyping.MaybeNestedInTuple[common.DomainLike], target: common.MutableField, ) -> None: - domain = common.normalize_domains(domain) - operators._tuple_assign_field(target, expr, domain) + domain_ = utils.tree_map(common.domain)(domain) + 
operators._tuple_assign_field(target, expr, domain_) @runtime.get_domain_range.register(EMBEDDED) From 97e0a33df5ae4775748feaf50caa4847193165a3 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 3 Nov 2025 09:18:31 +0100 Subject: [PATCH 44/44] rename --- src/gt4py/next/iterator/embedded.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index db3b98fed6..efbad21c2b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1631,11 +1631,11 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider @runtime.set_at.register(EMBEDDED) def set_at( expr: common.Field, - domain: xtyping.MaybeNestedInTuple[common.DomainLike], + domain_like: xtyping.MaybeNestedInTuple[common.DomainLike], target: common.MutableField, ) -> None: - domain_ = utils.tree_map(common.domain)(domain) - operators._tuple_assign_field(target, expr, domain_) + domain = utils.tree_map(common.domain)(domain_like) + operators._tuple_assign_field(target, expr, domain) @runtime.get_domain_range.register(EMBEDDED)