From ab9b51b4928d7deb8f459561fc74f69691db73dd Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 05:02:13 +0100 Subject: [PATCH 1/3] Clip domain --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 27900b6db6..679ffedbe4 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -46,7 +46,10 @@ class SymbolicRange: stop: itir.Expr def translate(self, distance: int) -> SymbolicRange: - return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) + start = im.plus(self.start, distance) + # TODO(tehrengruber): temporary solution to avoid oob-access without concat_where + start = im.call("maximum")(0, start) + return SymbolicRange(start, im.plus(self.stop, distance)) @dataclasses.dataclass From 04fa40b888c05940f2657895676ba4c2020da669 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 07:11:57 +0100 Subject: [PATCH 2/3] Don't use negative domains in domain inference tests --- .../transforms_tests/test_domain_inference.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 4a2a441510..204e6c4b05 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -224,8 +224,8 @@ def test_laplace(offset_provider): im.deref(im.shift("Joff", -1)("arg0")), ) ) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) - expected_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}} + domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (1, 7)}) + expected_domains = {"in_field1": {IDim: (0, 12), JDim: (0, 8)}} testee, expected = setup_test_as_fieldop(stencil, domain) run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -238,10 +238,10 @@ def test_shift_x_y_two_inputs(offset_provider): im.deref(im.shift("Joff", 1)("arg1")), ) ) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (0, 7)}) expected_domains = { - "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, - "in_field2": {IDim: (0, 11), JDim: (1, 8)}, + "in_field1": {IDim: (0, 10), JDim: (0, 7)}, + "in_field2": {IDim: (1, 11), JDim: (1, 8)}, } testee, expected = setup_test_as_fieldop( stencil, @@ -257,9 +257,9 @@ def test_shift_x_y_two_inputs_literal(offset_provider): im.deref(im.shift("Joff", 1)("arg1")), ) ) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (0, 7)}) expected_domains = { - "in_field1": {IDim: (-1, 10), JDim: (0, 7)}, + "in_field1": {IDim: (0, 10), JDim: (0, 7)}, } testee, expected = setup_test_as_fieldop( stencil, @@ -279,11 +279,11 @@ def test_shift_x_y_z_three_inputs(offset_provider): im.deref(im.shift("Koff", -1)("arg2")), ) ) - domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (0, 3)} + domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (1, 3)} expected_domains = { - "in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (0, 3)}, - "in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (0, 3)}, - "in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (-1, 2)}, + "in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (1, 3)}, + "in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (1, 3)}, + "in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (0, 2)}, } testee, expected = setup_test_as_fieldop( stencil, @@ -329,7 +329,7 @@ def test_nested_stencils(offset_provider): tmp = im.as_fieldop(inner_stencil)(im.ref("in_field1"), im.ref("in_field2")) testee = im.as_fieldop(stencil)(im.ref("in_field1"), tmp) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (1, 7)}) domain_inner = translate_domain(domain, {"Ioff": 0, "Joff": -1}, offset_provider) expected_inner = im.as_fieldop(inner_stencil, domain_inner)( @@ -338,7 +338,7 @@ def test_nested_stencils(offset_provider): expected = im.as_fieldop(stencil, domain)(im.ref("in_field1"), expected_inner) expected_domains = { - "in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}), + "in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7)}), "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( @@ -515,9 +515,9 @@ def test_cond(offset_provider): testee = im.if_(cond, field_1, field_2) - domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) + domain = im.domain(common.GridType.CARTESIAN, {"IDim": (2, 13)}) domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) - expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}} + expected_domains_dict = {"in_field1": {IDim: (2, 14)}, "in_field2": {IDim: (0, 14)}} expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)( im.ref("in_field1"), im.ref("in_field2") ) @@ -731,7 +731,7 @@ def test_nested_let_args(offset_provider): ), )(premap_field("inner", "Ioff", -1)) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (2, 11)}) domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider) domain_m2 = translate_domain(domain, {"Ioff": -2}, offset_provider) @@ -754,9 +754,9 @@ def test_program_let(offset_provider): let_tmp = im.let("inner", premap_field("outer", "Ioff", -1))(premap_field("inner", "Ioff", -1)) as_fieldop = im.as_fieldop(stencil_tmp)(im.ref("tmp")) - domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-2, 10)}) - domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 11)}) + domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (2, 11)}) + domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11)}) params = [im.sym(name) for name in ["in_field", "out_field", "outer"]] From 45ad53a86861aa23191197572fdec7782f6a8abc Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 19 Mar 2025 11:57:15 +0100 Subject: [PATCH 3/3] Remaining fixes to get PMAP-G to work --- .../iterator/transforms/fuse_as_fieldop.py | 9 ++++++++- .../next/iterator/transforms/global_tmps.py | 6 ++++-- .../inline_center_deref_lift_vars.py | 19 ++++++++++++------- src/gt4py/next/type_system/type_info.py | 2 +- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 81633dfb87..8b54ab08a9 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -345,9 +345,15 @@ def apply( if not uids: uids = eve_utils.UIDGenerator() - return cls(uids=uids, enabled_transformations=enabled_transformations).visit( + new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit( node, within_set_at_expr=within_set_at_expr ) + new_node = type_inference.infer( + new_node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, + ) + return new_node def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): if not cpm.is_call_to(node, "make_tuple"): @@ -429,6 +435,7 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): return None def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs): + return None # when multiple `as_fieldop` calls are fused that use the same argument, this argument # might become referenced once only. In order to be able to continue fusing such arguments # try inlining here. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index ac7fcb8f1c..5ae3ee2558 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -81,8 +81,10 @@ def _transform_by_pattern( # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level # of a SetAt, the CollapseTuple pass will eliminate most of this cases. if isinstance(domain, tuple): - flattened_domains: tuple[domain_utils.SymbolicDomain] = ( - next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough + flattened_domains: tuple[domain_utils.SymbolicDomain] = tuple( + domain + for domain in next_utils.flatten_nested_tuple(domain) + if domain is not infer_domain.DomainAccessDescriptor.NEVER # type: ignore[assignment] # mypy not smart enough ) if not all(d == flattened_domains[0] for d in flattened_domains): raise NotImplementedError( diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index c0a8c9f1b7..3505870aeb 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -17,6 +17,7 @@ from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.type_system import inference as type_inference def is_center_derefed_only(node: itir.Node) -> bool: @@ -95,17 +96,21 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): for i, (param, arg) in enumerate(zip(node.fun.params, node.args)): if cpm.is_applied_lift(arg) and is_center_derefed_only(param): eligible_params[i] = True - bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv") - capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)()) - trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) - new_args.append(capture_lift) - # since we deref an applied lift here we can (but don't need to) immediately - # inline - evaluators[bound_arg_evaluator] = im.lambda_()( + bound_arg_evaluator_name = self.uids.sequential_id(prefix="__icdlv") + bound_arg_evaluator = im.lambda_()( InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit( im.deref(arg), recurse=False ) ) + capture_lift = im.promote_to_const_iterator( + im.call(im.ref(bound_arg_evaluator_name, bound_arg_evaluator.type))() + ) + trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) + + new_args.append(capture_lift) + # since we deref an applied lift here we can (but don't need to) immediately + # inline + evaluators[bound_arg_evaluator_name] = bound_arg_evaluator else: new_args.append(arg) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index bbaaa82728..8e8cf0b131 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -491,7 +491,7 @@ def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool: is_compatible &= is_compatible_type(arg_a, arg_b) is_compatible &= is_compatible_type(type_a.returns, type_b.returns) else: - is_compatible &= is_concretizable(type_a, type_b) + is_compatible &= is_concretizable(type_a, type_b) or is_concretizable(type_b, type_a) return is_compatible