diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 94459fce90..c1b08019cc 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -910,7 +910,9 @@ def remapping(cls) -> ConnectivityKind: class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import domain: tuple[Dimension, ...] codomain: Dimension - skip_value: Optional[core_defs.IntegralScalar] + skip_value: Optional[ + core_defs.IntegralScalar + ] # TODO(tehrengruber): isn't this a value of the `NeighborConnectivityType` only dtype: core_defs.DType @property @@ -918,6 +920,12 @@ def has_skip_values(self) -> bool: return self.skip_value is not None +@dataclasses.dataclass(frozen=True) +class CartesianConnectivityType(ConnectivityType): + domain: tuple[Dimension] + offset: int + + @dataclasses.dataclass(frozen=True) class NeighborConnectivityType(ConnectivityType): # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain @@ -932,8 +940,7 @@ def neighbor_dim(self) -> Dimension: return self.domain[1] -@runtime_checkable -class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT_co]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Generic[DimsT, DimT_co]): @property @abc.abstractmethod def codomain(self) -> DimT_co: @@ -947,22 +954,8 @@ def codomain(self) -> DimT_co: Currently, this would just complicate implementation as we do not use this information. """ - def __gt_type__(self) -> ConnectivityType: - if is_neighbor_connectivity(self): - return NeighborConnectivityType( - domain=self.domain.dims, - codomain=self.codomain, - dtype=self.dtype, - skip_value=self.skip_value, - max_neighbors=self.ndarray.shape[1], - ) - else: - return ConnectivityType( - domain=self.domain.dims, - codomain=self.codomain, - dtype=self.dtype, - skip_value=self.skip_value, - ) + @abc.abstractmethod + def __gt_type__(self) -> ConnectivityType: ... @property def kind(self) -> ConnectivityKind: @@ -1034,6 +1027,115 @@ def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: raise TypeError("'Connectivity' does not support this operation.") +DomainDimT = TypeVar("DomainDimT", bound="Dimension") + + +@dataclasses.dataclass(frozen=True, eq=False) +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): + domain_dim: DomainDimT + codomain: DimT + offset: int = 0 + + def __init__( + self, domain_dim: DomainDimT, offset: int = 0, *, codomain: Optional[DimT] = None + ) -> None: + object.__setattr__(self, "domain_dim", domain_dim) + object.__setattr__(self, "codomain", codomain if codomain is not None else domain_dim) + object.__setattr__(self, "offset", offset) + + @classmethod + def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] + raise NotImplementedError() + + @property + def ndarray(self) -> Never: + raise NotImplementedError() + + def asnumpy(self) -> Never: + raise NotImplementedError() + + def as_scalar(self) -> Never: + raise NotImplementedError() + + @functools.cached_property + def domain(self) -> Domain: + return Domain(dims=(self.domain_dim,), ranges=(UnitRange.infinite(),)) + + @property + def __gt_origin__(self) -> Never: + raise TypeError("'CartesianConnectivity' does not support this operation.") + + def __gt_type__(self) -> CartesianConnectivityType: + assert len(self.domain.dims) == 1 + return CartesianConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + offset=self.offset, + ) + + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + + # This is a workaround to make this class concrete, since `codomain` is an + # abstract property of the `Connectivity` Protocol. + if not TYPE_CHECKING: + + @functools.cached_property + def codomain(self) -> DimT: + raise RuntimeError("This property should be always set in the constructor.") + + @property + def skip_value(self) -> None: + return None + + @functools.cached_property + def kind(self) -> ConnectivityKind: + return ( + ConnectivityKind.translation() + if self.domain_dim == self.codomain + else ConnectivityKind.relocation() + ) + + @classmethod + def for_translation( + cls, dimension: DomainDimT, offset: int + ) -> CartesianConnectivity[DomainDimT, DomainDimT]: + return cast(CartesianConnectivity[DomainDimT, DomainDimT], cls(dimension, offset)) + + @classmethod + def for_relocation(cls, old: DimT, new: DomainDimT) -> CartesianConnectivity[DomainDimT, DimT]: + return cls(new, codomain=old) + + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + if not isinstance(image_range, UnitRange): + if image_range.dim != self.codomain: + raise ValueError( + f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." + ) + + image_range = image_range.unit_range + + assert isinstance(image_range, UnitRange) + return (named_range((self.domain_dim, image_range - self.offset)),) + + def premap( + self, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: + raise NotImplementedError() + + __call__ = premap + + def restrict(self, index: AnyIndexSpec) -> Never: + raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case + + __getitem__ = restrict + + # Utility function to construct a `Field` from different buffer representations. # Consider removing this function and using `Field` constructor directly. See also `_connectivity`. @functools.singledispatch @@ -1061,12 +1163,19 @@ def _connectivity( raise NotImplementedError -class NeighborConnectivity(Connectivity, Protocol): - # TODO(havogt): work towards encoding this properly in the type - def __gt_type__(self) -> NeighborConnectivityType: ... +class NeighborConnectivity(Connectivity[DimsT, DimT_co]): + def __gt_type__(self) -> NeighborConnectivityType: + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + # TODO: reevaluate if not isinstance(obj, Connectivity): return False domain_dims = obj.domain.dims @@ -1078,7 +1187,7 @@ def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: class NeighborTable( - NeighborConnectivity, Protocol + NeighborConnectivity ): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) @property def ndarray(self) -> core_defs.NDArrayObject: @@ -1088,12 +1197,15 @@ def ndarray(self) -> core_defs.NDArrayObject: ... -def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: - return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") +# TODO: delete. A protocol and duck typing in it's current form is not enough since we use the +# type of the connectivity to propagate structural information, e.g. that we have a cartesian +# and not a neighbor connectivity. We would need to extend the protocol for this +# def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: +# return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") -OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity -OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType +OffsetProviderElem: TypeAlias = CartesianConnectivity | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = CartesianConnectivityType | NeighborConnectivityType # Note: `OffsetProvider` and `OffsetProviderType` should not be accessed directly, # use the `get_offset` and `get_offset_type` functions instead. OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] @@ -1133,8 +1245,6 @@ def get_offset(offset_provider: OffsetProvider, offset_tag: str) -> OffsetProvid `OffsetProviderType` should go through this function. """ # TODO(havogt): Once we have a custom class for `OffsetProvider`, we can absorb this functionality into it. - if offset_tag.startswith(_IMPLICIT_OFFSET_PREFIX): - return Dimension(value=_get_dimension_name_from_implicit_offset(offset_tag)) if offset_tag not in offset_provider: raise KeyError(f"Offset '{offset_tag}' not found in offset provider.") return offset_provider[offset_tag] # TODO return a valid dimension @@ -1165,105 +1275,6 @@ def hash_offset_provider_items_by_id(offset_provider: OffsetProvider) -> int: return hash(tuple((k, id(v)) for k, v in offset_provider.items())) -DomainDimT = TypeVar("DomainDimT", bound="Dimension") - - -@dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): - domain_dim: DomainDimT - codomain: DimT - offset: int = 0 - - def __init__( - self, domain_dim: DomainDimT, offset: int = 0, *, codomain: Optional[DimT] = None - ) -> None: - object.__setattr__(self, "domain_dim", domain_dim) - object.__setattr__(self, "codomain", codomain if codomain is not None else domain_dim) - object.__setattr__(self, "offset", offset) - - @classmethod - def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] - raise NotImplementedError() - - @property - def ndarray(self) -> Never: - raise NotImplementedError() - - def asnumpy(self) -> Never: - raise NotImplementedError() - - def as_scalar(self) -> Never: - raise NotImplementedError() - - @functools.cached_property - def domain(self) -> Domain: - return Domain(dims=(self.domain_dim,), ranges=(UnitRange.infinite(),)) - - @property - def __gt_origin__(self) -> Never: - raise TypeError("'CartesianConnectivity' does not support this operation.") - - @property - def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: - return core_defs.Int32DType() # type: ignore[return-value] - - # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `Connectivity` Protocol. - if not TYPE_CHECKING: - - @functools.cached_property - def codomain(self) -> DimT: - raise RuntimeError("This property should be always set in the constructor.") - - @property - def skip_value(self) -> None: - return None - - @functools.cached_property - def kind(self) -> ConnectivityKind: - return ( - ConnectivityKind.translation() - if self.domain_dim == self.codomain - else ConnectivityKind.relocation() - ) - - @classmethod - def for_translation( - cls, dimension: DomainDimT, offset: int - ) -> CartesianConnectivity[DomainDimT, DomainDimT]: - return cast(CartesianConnectivity[DomainDimT, DomainDimT], cls(dimension, offset)) - - @classmethod - def for_relocation(cls, old: DimT, new: DomainDimT) -> CartesianConnectivity[DomainDimT, DimT]: - return cls(new, codomain=old) - - def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: - if not isinstance(image_range, UnitRange): - if image_range.dim != self.codomain: - raise ValueError( - f"Dimension '{image_range.dim}' does not match the codomain dimension '{self.codomain}'." - ) - - image_range = image_range.unit_range - - assert isinstance(image_range, UnitRange) - return (named_range((self.domain_dim, image_range - self.offset)),) - - def premap( - self, - index_field: Connectivity | fbuiltins.FieldOffset, - *args: Connectivity | fbuiltins.FieldOffset, - ) -> Connectivity: - raise NotImplementedError() - - __call__ = premap - - def restrict(self, index: AnyIndexSpec) -> Never: - raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case - - __getitem__ = restrict - - @enum.unique class GridType(StrEnum): CARTESIAN = "cartesian" @@ -1274,7 +1285,13 @@ def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]: """Find the canonical ordering of the dimensions in `dims`.""" if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") - return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value)) + return sorted( + dims, + key=lambda dim: ( + _DIM_KIND_ORDER[dim.kind], + flip_staggered(dim).value if is_staggered(dim) else dim.value, + ), + ) def check_dims(dims: Sequence[Dimension]) -> None: @@ -1362,3 +1379,30 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call #: Equivalent to the `_FillValue` attribute in the UGRID Conventions #: (see: http://ugrid-conventions.github.io/ugrid-conventions/). _DEFAULT_SKIP_VALUE: Final[int] = -1 +_STAGGERED_PREFIX = "_Staggered" + + +def is_staggered(dim: Dimension) -> bool: + return dim.value.startswith(_STAGGERED_PREFIX) + + +def flip_staggered(dim: Dimension) -> Dimension: + if is_staggered(dim): + return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind) + else: + return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind) + +def as_non_staggered(dim: Dimension) -> Dimension: + if is_staggered(dim): + return flip_staggered(dim) + return dim + +def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity: + if isinstance(offset, float): + integral_offset, half = divmod(offset, 1) + assert half == 0.5 + if not dim.value.startswith(_STAGGERED_PREFIX): + integral_offset += 1 + return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim)) + else: + return CartesianConnectivity(dim, offset, codomain=dim) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index f4aee67332..b6adf2bcee 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -478,7 +478,7 @@ def _dace_descriptor(self) -> dace.data.Data: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( - common.Connectivity[common.DimsT, common.DimT], + common.NeighborConnectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -1017,7 +1017,7 @@ def _builtin_op( offset_definition = common.get_offset( current_offset_provider, axis.value ) # assumes offset and local dimension have same name - assert common.is_neighbor_table(offset_definition) + assert isinstance(offset_definition, common.NeighborConnectivity) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..5265cfe716 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -10,7 +10,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next import errors, utils +from gt4py.next import common, errors, utils from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, @@ -655,13 +655,22 @@ def _deduce_compare_type( def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: - # e.g. `IDim+1` + # e.g. `IDim+1` or `IDim+0.5` if ( isinstance(left.type, ts.DimensionType) and isinstance(right.type, ts.ScalarType) - and type_info.is_integral(right.type) + and type_info.is_arithmetic(right.type) ): - return ts.OffsetType(source=left.type.dim, target=(left.type.dim,)) + if not isinstance(right, foast.Constant): + raise NotImplementedError( + "Cartesian offsets are only supported with literal rhs, e.g. `IDim + 1`, but not `IDim + expr`." + ) + offset_index = right.value + if node.op == dialect_ast_enums.BinaryOperator.SUB: + offset_index *= -1 + conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index) + return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,)) + if isinstance(left.type, ts.OffsetType): raise errors.DSLError( node.location, f"Type '{left.type}' can not be used in operator '{node.op}'." diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index d74fb5dce8..3a5839450e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -302,17 +302,19 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # `field(Dim + idx)` case foast.BinOp( op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs + left=foast.Name(), # TODO(tehrengruber): use type instead right=foast.Constant(value=offset_index), ): if arg.op == dialect_ast_enums.BinaryOperator.SUB: offset_index *= -1 - # TODO(havogt): we rely on the naming-convention for implicit offsets, see `dimension_to_implicit_offset` + conn = common.connectivity_for_cartesian_shift( + node.args[0].left.type.dim, offset_index + ) current_expr = im.as_fieldop( im.lambda_("__it")( im.deref( im.shift( - common.dimension_to_implicit_offset(dimension), offset_index + im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset )("__it") ) ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5e3ab441d2..1b9b17e65e 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -88,10 +88,6 @@ def __str__(self): InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") -class OffsetLiteral(Expr): - value: Union[int, str] - - class AxisLiteral(Expr): # TODO(havogt): Refactor to use declare Axis/Dimension at the Program level. # Now every use of the literal has to provide the kind, where usually we only care of the name. @@ -99,6 +95,16 @@ class AxisLiteral(Expr): kind: common.DimensionKind = common.DimensionKind.HORIZONTAL +class CartesianOffset(Expr): + domain: AxisLiteral + codomain: AxisLiteral + + +# TODO(tehrengruber): allow int only and create OffsetRef for str instead +class OffsetLiteral(Expr): + value: Union[int, str] + + class SymRef(Expr): id: Coerced[SymbolRef] @@ -157,8 +163,9 @@ class Program(Node, ValidatedSymbolTableTrait): Expr.__hash__ = Node.__hash__ # type: ignore[method-assign] Literal.__hash__ = Node.__hash__ # type: ignore[method-assign] NoneLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] -OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] AxisLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +CartesianOffset.__hash__ = Node.__hash__ # type: ignore[method-assign] SymRef.__hash__ = Node.__hash__ # type: ignore[method-assign] Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 3fa088d785..124d855721 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -116,27 +116,44 @@ def translate( #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: - offset_provider_type = common.offset_provider_to_type(offset_provider) - dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} if len(shift) == 0: return self if len(shift) == 2: off, val = shift - assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str) - connectivity_type = common.get_offset_type(offset_provider_type, off.value) - if isinstance(connectivity_type, common.Dimension): + connectivity: common.Connectivity + if isinstance(off, itir.CartesianOffset): + domain = common.Dimension(value=off.domain.value, kind=off.domain.kind) + codomain = common.Dimension(value=off.codomain.value, kind=off.codomain.kind) + connectivity = common.CartesianConnectivity(domain, codomain=codomain) + elif isinstance(off, itir.OffsetLiteral): + assert isinstance(off.value, str) + connectivity = common.get_offset(offset_provider, off.value) + else: + raise AssertionError() + + if isinstance(connectivity, common.CartesianConnectivity): if val is trace_shifts.Sentinel.VALUE: raise NotImplementedError("Dynamic offsets not supported.") assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) - current_dim = connectivity_type + assert len(connectivity.domain.dims) == 1 # cartesian offset - new_ranges[current_dim] = SymbolicRange.translate( - self.ranges[current_dim], val.value + + old_dim = connectivity.domain.dims[0] + new_dim = connectivity.codomain + + assert new_dim not in new_ranges or old_dim == new_dim + + new_range = SymbolicRange.translate( + self.ranges[old_dim], connectivity.offset + val.value + ) + new_ranges = dict( + (dim, range_) if dim != old_dim else (new_dim, new_range) + for dim, range_ in new_ranges.items() ) - elif isinstance(connectivity_type, common.NeighborConnectivityType): + elif isinstance(connectivity, common.NeighborConnectivity): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -157,8 +174,8 @@ def translate( for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - old_dim = connectivity_type.source_dim - new_dim = connectivity_type.codomain + old_dim = connectivity.domain.dims[0] + new_dim = connectivity.codomain assert new_dim not in new_ranges or old_dim == new_dim @@ -172,6 +189,7 @@ def translate( ) else: raise AssertionError() + return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index fefca65a62..33f11c4d54 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -578,6 +578,10 @@ def axis_literal(dim: common.Dimension) -> itir.AxisLiteral: return itir.AxisLiteral(value=dim.value, kind=dim.kind) +def cartesian_offset(domain: common.Dimension, codomain: common.Dimension): + return itir.CartesianOffset(domain=axis_literal(domain), codomain=axis_literal(codomain)) + + def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): """ Promotes the function `cast_` to a field_operator. diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 4f890019ec..87af68bec6 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -231,6 +231,20 @@ def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: return common.GridType.UNSTRUCTURED +def dim_from_axis_literal(axis_literal: itir.AxisLiteral) -> common.Dimension: + return common.Dimension(value=axis_literal.value, kind=axis_literal.kind) + + +def connectivity_from_cartesian_offset( + cart_offset: itir.CartesianOffset, +) -> common.CartesianConnectivity: + return common.CartesianConnectivity( + domain_dim=dim_from_axis_literal(cart_offset.domain), + codomain=dim_from_axis_literal(cart_offset.codomain), + offset=0, + ) + + def _flatten_tuple_expr(expr: itir.Expr) -> tuple[itir.Expr]: if cpm.is_call_to(expr, "make_tuple"): return sum( diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 5063e26392..c827641560 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -143,6 +143,11 @@ def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[ def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] + def visit_CartesianOffset(self, node: ir.CartesianOffset, *, prec: int) -> list[str]: + (domain,) = self.visit(node.domain, prec=0) + (codomain,) = self.visit(node.codomain, prec=0) + return [f"{domain}₂{codomain}"] + def visit_AxisLiteral(self, node: ir.AxisLiteral, *, prec: int) -> list[str]: kind = "" if node.kind == ir.DimensionKind.HORIZONTAL: diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 8173ceebbb..b198bd36a4 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -137,7 +137,8 @@ def _can_deref(x): def _shift(*offsets): assert all( - isinstance(offset, ir.OffsetLiteral) or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE] + isinstance(offset, (ir.OffsetLiteral, ir.CartesianOffset)) + or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE] for offset in offsets ) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 8ea0e43f50..bf67b643b4 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -473,6 +473,17 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.Offse assert isinstance(node.value, str) return it_ts.OffsetLiteralType(value=node.value) + def visit_CartesianOffset( + self, node: itir.CartesianOffset, *, ctx + ) -> it_ts.CartesianOffsetType | ts.DeferredType: + self.visit(node.domain, ctx=ctx) + self.visit(node.codomain, ctx=ctx) + domain, codomain = node.domain.type, node.codomain.type + if domain is None or codomain is None: + return ts.DeferredType(constraint=it_ts.CartesianOffsetType) + assert isinstance(domain, ts.DimensionType) and isinstance(codomain, ts.DimensionType) + return it_ts.CartesianOffsetType(domain=domain.dim, codomain=codomain.dim) + def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 39e9e607ce..40c6dde6d7 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -20,6 +20,11 @@ class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | str +class CartesianOffsetType(ts.TypeSpec): + domain: common.Dimension + codomain: common.Dimension + + class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d77c70375..0b224fd75d 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -16,8 +16,8 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common, utils -from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator.ir_utils import misc as ir_misc from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -458,7 +458,7 @@ def _canonicalize_nb_fields( def _resolve_dimensions( input_dims: list[common.Dimension], - shift_tuple: tuple[itir.OffsetLiteral, ...], + shift_tuple: tuple[itir.OffsetLiteral | itir.CartesianOffset, ...], offset_provider_type: common.OffsetProviderType, ) -> list[common.Dimension]: """ @@ -486,14 +486,25 @@ def _resolve_dimensions( >>> Edge = common.Dimension(value="Edge") >>> Vertex = common.Dimension(value="Vertex") + >>> Cell = common.Dimension(value="Cell") >>> K = common.Dimension(value="K", kind=common.DimensionKind.VERTICAL) >>> V2E = common.Dimension(value="V2E") + >>> C2V = common.Dimension(value="C2V") >>> input_dims = [Edge, K] >>> shift_tuple = ( + ... itir.OffsetLiteral(value="C2V"), + ... itir.OffsetLiteral(value=0), ... itir.OffsetLiteral(value="V2E"), ... itir.OffsetLiteral(value=0), ... ) >>> offset_provider_type = { + ... "C2V": common.NeighborConnectivityType( + ... domain=(Cell, C2V), + ... codomain=Vertex, + ... skip_value=None, + ... dtype=None, + ... max_neighbors=3, + ... ), ... "V2E": common.NeighborConnectivityType( ... domain=(Vertex, V2E), ... codomain=Edge, @@ -504,21 +515,49 @@ def _resolve_dimensions( ... "KOff": K, ... } >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) - [Dimension(value='Vertex', kind=), Dimension(value='K', kind=)] + [Dimension(value='Cell', kind=), Dimension(value='K', kind=)] + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = common.Dimension(value="IDim") + >>> IHalfDim = common.flip_staggered(IDim) + >>> JDim = common.Dimension(value="JDim") + >>> JHalfDim = common.flip_staggered(JDim) + >>> input_dims = [IDim, JDim] + >>> shift_tuple = ( + ... itir.CartesianOffset( + ... domain=im.axis_literal(IDim), codomain=im.axis_literal(IHalfDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset(domain=im.axis_literal(JDim), codomain=im.axis_literal(IDim)), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset( + ... domain=im.axis_literal(IHalfDim), codomain=im.axis_literal(JDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... ) + >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) + [Dimension(value='JDim', kind=), Dimension(value='IDim', kind=)] + """ resolved_dims = [] for input_dim in input_dims: + resolved_dim = input_dim for off_literal in reversed( shift_tuple[::2] - ): # Only OffsetLiterals are processed, located at even indices in shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is applied first. - assert isinstance(off_literal.value, str) - offset_type = common.get_offset_type(offset_provider_type, off_literal.value) - if isinstance(offset_type, common.Dimension) and input_dim == offset_type: - continue # No shift applied - if isinstance(offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)): - if input_dim == offset_type.codomain: # Check if input fits to offset - input_dim = offset_type.domain[0] # Update input_dim for next iteration - resolved_dims.append(input_dim) + ): # Only OffsetLiterals/CartesianOffsets are processed, located at even indices in shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is applied first. + if isinstance(off_literal, itir.CartesianOffset): + if resolved_dim == ir_misc.dim_from_axis_literal(off_literal.codomain): + resolved_dim = ir_misc.dim_from_axis_literal(off_literal.domain) + else: + assert isinstance(off_literal, itir.OffsetLiteral) and isinstance( + off_literal.value, str + ) + offset_type = common.get_offset_type(offset_provider_type, off_literal.value) + if isinstance(offset_type, common.Dimension) and resolved_dim == offset_type: + continue # No shift applied + if isinstance(offset_type, common.NeighborConnectivityType): + if resolved_dim == offset_type.codomain: # Check if input fits to offset + resolved_dim = offset_type.domain[0] # Update input_dim for next iteration + resolved_dims.append(resolved_dim) return resolved_dims @@ -664,22 +703,26 @@ def apply_shift( new_position_dims = [*it.position_dims] assert len(offset_literals) % 2 == 0 for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, str - ) - type_ = common.get_offset_type(offset_provider_type, offset_axis.value) - if isinstance(type_, common.Dimension): - pass - elif isinstance(type_, common.NeighborConnectivityType): - found = False - for i, dim in enumerate(new_position_dims): - if dim.value == type_.source_dim.value: - assert not found - new_position_dims[i] = type_.codomain - found = True - assert found + source_dim: common.Dimension + target_dim: common.Dimension + if isinstance(offset_axis, it_ts.CartesianOffsetType): + source_dim, target_dim = offset_axis.domain, offset_axis.codomain else: - raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + assert isinstance(offset_axis, it_ts.OffsetLiteralType) + assert isinstance(offset_axis.value, str) + type_ = common.get_offset_type(offset_provider_type, offset_axis.value) + assert isinstance( + type_, (common.CartesianConnectivityType, common.NeighborConnectivityType) + ) + source_dim, target_dim = type_.domain[0], type_.codomain + + found = False + for i, dim in enumerate(new_position_dims): + if dim == source_dim: + assert not found + new_position_dims[i] = target_dim + found = True + assert found else: # during re-inference we don't have an offset provider type new_position_dims = "unknown" diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 041868b00e..b46f29293a 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,9 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | Tuple: source_buffer=name, dimensions=[ DimensionSpec( - name=dim.value, + name=dim.value + if not common.is_staggered(dim) + else common.flip_staggered(dim).value, static_stride=1 if ( config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e03fa84e50..7a0400ea91 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,6 +10,7 @@ import dataclasses import pathlib +import types from typing import Protocol, TypeVar import factory @@ -42,6 +43,9 @@ def __call__( ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... +_MODULES: list[types.ModuleType] = [] + + @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ @@ -83,11 +87,12 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - compiled_prog = getattr( - importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name - ) + m = importer.import_from_path(src_dir / new_data.module) + # Keep a reference to the module so they are not garbage collected. This avoids a SEGFAULT + # in nanobind when calling the compiled program. + _MODULES.append(m) - return compiled_prog + return getattr(m, new_data.entry_point_name) class CompilerFactory(factory.Factory): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 0c76757d70..9a548d9fab 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -143,7 +143,7 @@ def _process_connectivity_args( arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity_type, common.Dimension): + elif isinstance(connectivity_type, common.CartesianConnectivityType): pass else: raise AssertionError( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ecd8ed88ed..9ecf9a3588 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -91,10 +91,18 @@ def _name_from_named_range(named_range_call: itir.FunCall) -> str: return named_range_call.args[0].value +class FlipStaggeredDims(eve.NodeTranslator): + def visit_AxisLiteral(self, node: itir.AxisLiteral) -> itir.AxisLiteral: + dim = ir_utils_misc.dim_from_axis_literal(node) + if common.is_staggered(dim): + return im.axis_literal(common.flip_staggered(dim)) + return node + + def _collect_dimensions_from_domain( body: Iterable[itir.Stmt], ) -> dict[str, TagDefinition]: - domains = _get_domains(body) + domains = FlipStaggeredDims().visit(_get_domains(body)) offset_definitions = {} for domain in domains: if domain.fun == itir.SymRef(id="cartesian_domain"): @@ -131,22 +139,27 @@ def _collect_offset_definitions( grid_type: common.GridType, offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: - used_offset_tags: set[str] = ( - node.walk_values() - .if_isinstance(itir.OffsetLiteral) - .filter(lambda offset_literal: isinstance(offset_literal.value, str)) - .getattr("value") - ).to_set() - # implicit offsets don't occur in the `offset_provider_type`, get them from the used offset tags - offset_provider_type = { - offset_name: common.get_offset_type(offset_provider_type, offset_name) - for offset_name in used_offset_tags - } | {**offset_provider_type} offset_definitions = {} + offset_provider_type = {**offset_provider_type} - for offset_name, dim_or_connectivity_type in offset_provider_type.items(): - if isinstance(dim_or_connectivity_type, common.Dimension): - dim: common.Dimension = dim_or_connectivity_type + cartesian_offsets: set[itir.CartesianOffset] = ( + node.walk_values().if_isinstance(itir.CartesianOffset) + ).to_set() + for cart_offset in cartesian_offsets: + dims = [ + common.Dimension(value=v.value, kind=v.kind) + for v in (cart_offset.domain, cart_offset.codomain) + ] + for dim in dims: + if common.is_staggered(dim): + dim = common.flip_staggered(dim) + offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) + + for offset_name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.CartesianConnectivityType): + if connectivity_type.domain[0] != connectivity_type.codomain: + raise NotImplementedError() + dim, *_ = connectivity_type.domain if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -166,9 +179,7 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) - elif isinstance( - connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType - ): + elif isinstance(connectivity_type := connectivity_type, common.NeighborConnectivityType): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) if offset_name != connectivity_type.neighbor_dim.value: @@ -372,8 +383,15 @@ def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal: def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral: return OffsetLiteral(value=node.value) + def visit_CartesianOffset(self, node: itir.CartesianOffset, **kwargs: Any) -> Literal: + return self.visit(node.codomain, **kwargs) + def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs: Any) -> Literal: - return Literal(value=node.value, type="axis_literal") + assert isinstance(node.type, ts.DimensionType) + dim = node.type.dim + if common.is_staggered(dim): + dim = common.flip_staggered(dim) + return Literal(value=dim.value, type="axis_literal") def _make_domain(self, node: itir.FunCall) -> tuple[TaggedValues, TaggedValues]: tags = [] diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 049e2a85be..3877e43bf2 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -32,7 +32,11 @@ from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as itir_misc, +) from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.program_processors.runners.dace.lowering import ( @@ -1528,12 +1532,13 @@ def _visit_shift_multidim( return offset_provider_arg, offset_value_arg, it def _make_cartesian_shift( - self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr + self, it: IteratorExpr, conn: gtx_common.CartesianConnectivityType, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert any(dim == offset_dim for dim, _ in it.field_domain) + (old_dim,) = conn.domain + new_dim = conn.codomain new_index: SymbolExpr | ValueExpr - index_expr = it.indices[offset_dim] + index_expr = it.indices[old_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( @@ -1596,9 +1601,9 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - shifted_indices = { - dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() - } + shifted_indices = dict( + (new_dim, new_index) if dim == old_dim else (dim, index) for dim, index in it.indices.items() + ) return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( @@ -1691,11 +1696,16 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: node.args[0], node.fun.args ) - # first argument of the shift node is the offset provider - assert isinstance(offset_provider_arg, gtir.OffsetLiteral) - offset = offset_provider_arg.value - assert isinstance(offset, str) - offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) + if isinstance(offset_provider_arg, gtir.CartesianOffset): + conn = itir_misc.connectivity_from_cartesian_offset(offset_provider_arg) + offset_provider_type = conn.__gt_type__() + else: + assert isinstance(offset_provider_arg, gtir.OffsetLiteral) + assert isinstance(offset_provider_arg.value, str) + offset_provider_type = self.subgraph_builder.get_offset_provider_type( + offset_provider_arg.value + ) + # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr = ( SymbolExpr(offset_value_arg.value, gtir_to_sdfg_types.INDEX_DTYPE) @@ -1703,13 +1713,14 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: else self.visit(offset_value_arg) ) - if isinstance(offset_provider_type, gtx_common.Dimension): + if isinstance(offset_provider_type, gtx_common.CartesianConnectivityType): return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: + assert isinstance(offset_value_arg, gtir.OffsetLiteral) # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = gtx_dace_args.connectivity_identifier(offset) + offset_table = gtx_dace_args.connectivity_identifier(offset_provider_arg.value) self.sdfg.arrays[offset_table].transient = False offset_table_node = self.state.add_access(offset_table) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 8b5496989f..0a3debd8a4 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -156,7 +156,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ used_connectivities: dict[str, gtx_common.NeighborConnectivity] = { conn_id: conn for offset, conn in self.connectivities.items() - if gtx_common.is_neighbor_table(conn) + if isinstance(conn, gtx_common.NeighborConnectivity) and (conn_id := gtx_dace_args.connectivity_identifier(offset)) in self.sdfg_closure_cache["arrays"] } diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7bd796785d..e435d64f66 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -84,7 +84,11 @@ def extract_connectivity_args( # Note: this function is on the hot path and needs to have minimal overhead. zero_origin = (0, 0) assert all( - hasattr(conn, "ndarray") or isinstance(conn, common.Dimension) + isinstance(conn, common.CartesianConnectivity) + or ( + isinstance(conn, common.NeighborConnectivity) + and field_utils.verify_device_field_type(conn, device) + ) for conn in offset_provider.values() ) # Note: the order here needs to agree with the order of the generated bindings. @@ -92,15 +96,10 @@ def extract_connectivity_args( # the keys' order is taken into account. Any modification to the hashing # of offset providers may break this assumption here. args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [ - (ndarray, zero_origin) + (conn.ndarray, zero_origin) for conn in offset_provider.values() - if (ndarray := getattr(conn, "ndarray", None)) is not None + if isinstance(conn, common.NeighborConnectivity) ] - assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) - for conn in offset_provider.values() - if hasattr(conn, "ndarray") - ) return args diff --git a/tests/next_tests/fixtures/past_common.py b/tests/next_tests/fixtures/past_common.py index 718db6c3a6..0b2681059c 100644 --- a/tests/next_tests/fixtures/past_common.py +++ b/tests/next_tests/fixtures/past_common.py @@ -15,9 +15,7 @@ IDim = gtx.Dimension("IDim") -Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) JDim = gtx.Dimension("JDim") -Joff = gtx.FieldOffset("Joff", source=JDim, target=(JDim,)) # TODO(tehrengruber): Improve test structure. Identity needs to be decorated diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 78e6c62781..31d04b39bd 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -48,12 +48,11 @@ E2VDim, Edge, IDim, - Ioff, + IHalfDim, JDim, - Joff, + JHalfDim, KDim, KHalfDim, - Koff, V2EDim, Vertex, exec_alloc_descriptor, @@ -67,6 +66,7 @@ # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +IHalfField: TypeAlias = gtx.Field[[IHalfDim], np.int32] # type: ignore [valid-type] JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] @@ -494,6 +494,7 @@ def verify_with_default_data( case: Case, fieldop: decorator.FieldOperator, ref: Callable, + offset_provider: Optional[OffsetProvider] = None, comparison: Callable[[Any, Any], bool] = tree_mapped_np_allclose, ) -> None: """ @@ -508,6 +509,8 @@ def verify_with_default_data( fieldview_prog: The field operator or program to be verified. ref: A callable which will be called with all the input arguments of the fieldview code, after applying ``.ndarray`` on the fields. + offset_provider: An override for the test case's offset_provider. + Use with care! comparison: A comparison function, which will be called as ``comparison(ref, )`` and should return a boolean. """ @@ -521,7 +524,7 @@ def verify_with_default_data( *inps, **kwfields, ref=ref(*ref_args), - offset_provider=case.offset_provider, + offset_provider=offset_provider, comparison=comparison, ) @@ -572,7 +575,7 @@ def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, default_sizes={**unstructured_case.default_sizes, KDim: 10}, - offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, + offset_provider={**unstructured_case.offset_provider}, ) @@ -724,7 +727,9 @@ def from_cartesian_grid_descriptor( IDim: grid_descriptor.sizes[0], JDim: grid_descriptor.sizes[1], KDim: grid_descriptor.sizes[2], - KHalfDim: grid_descriptor.sizes[3], + IHalfDim: grid_descriptor.sizes[0] - 1, + JHalfDim: grid_descriptor.sizes[1] - 1, + KHalfDim: grid_descriptor.sizes[2] - 1, }, grid_type=common.GridType.CARTESIAN, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 7640553e6a..39297591e9 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -30,9 +30,6 @@ "IDim", "JDim", "KDim", - "Ioff", - "Joff", - "Koff", "Vertex", "Edge", "Cell", @@ -133,12 +130,11 @@ def debug_itir(tree): DType = TypeVar("DType") IDim = gtx.Dimension("IDim") +IHalfDim = common.flip_staggered(IDim) JDim = gtx.Dimension("JDim") +JHalfDim = common.flip_staggered(JDim) KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) -Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) -Joff = gtx.FieldOffset("Joff", source=JDim, target=(JDim,)) -Koff = gtx.FieldOffset("Koff", source=KDim, target=(KDim,)) +KHalfDim = common.flip_staggered(KDim) Vertex = gtx.Dimension("Vertex") Edge = gtx.Dimension("Edge") @@ -172,18 +168,13 @@ def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_cartesian_grid( - sizes: int | tuple[int, int, int, int] = (5, 7, 9, 11), + sizes: int | tuple[int, int, int, int] = (5, 7, 9), ) -> CartesianGridDescriptor: if isinstance(sizes, int): - sizes = (sizes,) * 4 - assert len(sizes) == 4, "sizes must be a tuple of four integers" + sizes = (sizes,) * 3 + assert len(sizes) == 3, "sizes must be a tuple of three integers" - offset_provider = { - "Ioff": IDim, - "Joff": JDim, - "Koff": KDim, - "KHalfoff": KHalfDim, - } + offset_provider = {} return types.SimpleNamespace( name="simple_cartesian_grid", diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..40f6ecc4c9 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -37,10 +37,8 @@ E2VDim, Edge, IDim, - Ioff, JDim, KDim, - Koff, V2EDim, Vertex, cartesian_case, @@ -119,7 +117,7 @@ def testee() -> cases.IFloatField: def test_cartesian_shift(cartesian_case): @gtx.field_operator def testee(a: cases.IJKField) -> cases.IJKField: - return a(Ioff[1]) + return a(IDim + 1) a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() @@ -205,8 +203,8 @@ def test_fold_shifts(cartesian_case): @gtx.field_operator def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField: - tmp = a + b(Ioff[1]) - return tmp(Ioff[1]) + tmp = a + b(IDim + 1) + return tmp(IDim + 1) a = cases.allocate(cartesian_case, testee, "a").extend({cases.IDim: (0, 1)})() b = cases.allocate(cartesian_case, testee, "b").extend({cases.IDim: (0, 2)})() @@ -355,7 +353,7 @@ def test_scalar_arg_with_field(cartesian_case): @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: tmp = b * a - return tmp(Ioff[1]) + return tmp(IDim + 1) a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, testee, "b")() @@ -452,7 +450,7 @@ def testee_scan(state: float, inp: float) -> float: @gtx.field_operator def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]: - return testee_scan(inp(Koff[1])) + return testee_scan(inp(KDim + 1)) inp = cases.allocate( cartesian_case, @@ -672,6 +670,9 @@ def testee() -> gtx.Field[[IDim], int_alias]: @pytest.mark.uses_dynamic_offsets def test_offset_field(cartesian_case): + Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) + Koff = gtx.FieldOffset("Koff", source=KDim, target=(KDim,)) + ref = np.full( (cartesian_case.default_sizes[IDim], cartesian_case.default_sizes[KDim]), True, dtype=bool ) @@ -682,8 +683,8 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD # note: this leads to an access to offset_field in # IDim: (0, out.size[I]), KDim: (0, out.size[K]+1) a_i_k = a_i(as_offset(Koff, offset_field)) - b_i = a(Ioff[1]) - b_i_k = b_i(Koff[1]) + b_i = a(IDim + 1) + b_i_k = b_i(KDim + 1) return a_i_k == b_i_k out = cases.allocate(cartesian_case, testee, cases.RETURN)() @@ -700,7 +701,10 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD a, offset_field, out=out, - offset_provider={"Ioff": IDim, "Koff": KDim}, + offset_provider={ + "Ioff": common.CartesianConnectivity(domain_dim=IDim), + "Koff": common.CartesianConnectivity(domain_dim=KDim), + }, ref=ref, comparison=lambda out, ref: np.all(out == ref), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 1a1984a71b..0043881703 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -21,11 +21,8 @@ V2E, Edge, IDim, - Ioff, JDim, - Joff, KDim, - Koff, V2EDim, Vertex, cartesian_case, @@ -146,7 +143,7 @@ def reduction(edge_f: EKField) -> VKField: @gtx.field_operator def fencil_op(edge_f: EKField) -> VKField: red = reduction(edge_f) - return red(Koff[1]) + return red(KDim + 1) @gtx.program def fencil(edge_f: EKField, out: VKField): @@ -377,7 +374,7 @@ def test_broadcast_shifted(cartesian_case): @gtx.field_operator def simple_broadcast(inp: cases.IField) -> cases.IJField: bcasted = broadcast(inp, (IDim, JDim)) - return bcasted(Joff[1]) + return bcasted(JDim + 1) cases.verify_with_default_data( cartesian_case, simple_broadcast, ref=lambda inp: inp[:, np.newaxis] @@ -439,7 +436,7 @@ def conditional_shifted( mask: cases.IBoolField, a: cases.IFloatField, b: cases.IFloatField ) -> gtx.Field[[IDim], float64]: tmp = where(mask, a, b) - return tmp(Ioff[1]) + return tmp(IDim + 1) @gtx.program def conditional_program( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 1abaa47d03..1099ae4671 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -19,7 +19,6 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( IDim, - Ioff, JDim, cartesian_case, exec_alloc_descriptor, @@ -56,7 +55,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): def test_shift_by_one_execution(cartesian_case): @gtx.field_operator def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: - return in_field(Ioff[1]) + return in_field(IDim + 1) # direct call to field operator # TODO(tehrengruber): slicing located fields not supported currently diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py new file mode 100644 index 0000000000..c33d56ac76 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py @@ -0,0 +1,142 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import functools +import math +from functools import reduce +from typing import TypeAlias + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next import ( + astype, + broadcast, + common, + errors, + float32, + float64, + int32, + int64, + minimum, + neighbor_sum, + utils as gt_utils, +) +from gt4py.next.ffront.experimental import as_offset + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import ( + C2E, + E2V, + V2E, + E2VDim, + Edge, + IDim, + IHalfDim, + JDim, + KDim, + V2EDim, + Vertex, + cartesian_case, + unstructured_case, + unstructured_case_3d, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) + + +@pytest.mark.uses_cartesian_shift +def test_copy_half_field(cartesian_case): + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IField: + return a(IDim + 1) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[1:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim + 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (-1, 0)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_back(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IDim + 0.5)(IHalfDim - 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus1(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IHalfDim + 1) # always pass an IHalf-index to an IHalfField + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out[:-1], ref=a[1:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_minus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim - 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, -1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_half2center(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IField: + return 2 * a(IDim + 0.5) # always pass an IHalf-index to an IHalfField + + a = cases.allocate(cartesian_case, testee, "a").extend({IHalfDim: (0, 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=2 * a[:], offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py index cd10c10437..2550692248 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -9,7 +9,7 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import IDim, JDim, KDim, Koff, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx from gt4py.next import int32 from gt4py.next.ffront.fbuiltins import where, broadcast @@ -25,7 +25,7 @@ def test_where_k_offset(cartesian_case): def fieldop_where_k_offset( inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType] ) -> cases.IKField: - return where(k_index > 0, inp(Koff[-1]), 2) + return where(k_index > 0, inp(KDim - 1), 2) @gtx.program def prog( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 09dc04acb1..2bac637e9b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -26,7 +26,6 @@ I = gtx.Dimension("I") -Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) @fundef @@ -80,7 +79,8 @@ def test_index_builtin(program_processor): def index_program_shift(out, size): set_at( as_fieldop( - lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size)) + lambda i: deref(i) + deref(shift("Ioff", 1)(i)), + cartesian_domain(named_range(I, 0, size)), )(index(I)), cartesian_domain(named_range(I, 0, size)), out, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 6d7fd9df2b..df143eda54 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -28,7 +28,6 @@ Cell = gtx.Dimension("Cell") KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) -Koff = gtx.FieldOffset("Koff", KDim, (KDim,)) class State(NamedTuple): @@ -59,9 +58,9 @@ def _solve_nonhydro_stencil_52_like( gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], bool] ]: """No projector required as we write all output of the scan (including dummy field)""" - z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) - z_c = z_beta * z_alpha(Koff[1]) - z_b = z_alpha * (z_beta(Koff[-1]) + z_beta) + z_a = z_beta(KDim - 1) * z_alpha(KDim - 1) + z_c = z_beta * z_alpha(KDim + 1) + z_b = z_alpha * (z_beta(KDim - 1) + z_beta) z_q_res, w_res, dummy = _scan(w, z_q, z_a, z_b, z_c) return z_q_res, w_res, dummy @@ -86,9 +85,9 @@ def _solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge( z_q: gtx.Field[[Cell, KDim], float], w: gtx.Field[[Cell, KDim], float], ) -> tuple[gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], float]]: - z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) - z_c = z_beta * z_alpha(Koff[1]) - z_b = z_alpha * (z_beta(Koff[-1]) + z_beta) + z_a = z_beta(KDim - 1) * z_alpha(KDim - 1) + z_c = z_beta * z_alpha(KDim + 1) + z_b = z_alpha * (z_beta(KDim - 1) + z_beta) z_q_res, w_res, _ = _scan(w, z_q, z_a, z_b, z_c) return z_q_res, w_res @@ -112,9 +111,9 @@ def _solve_nonhydro_stencil_52_like_z_q( z_q: gtx.Field[[Cell, KDim], float], w: gtx.Field[[Cell, KDim], float], ) -> gtx.Field[[Cell, KDim], float]: - z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) - z_c = z_beta * z_alpha(Koff[1]) - z_b = z_alpha * (z_beta(Koff[-1]) + z_beta) + z_a = z_beta(KDim - 1) * z_alpha(KDim - 1) + z_c = z_beta * z_alpha(KDim + 1) + z_b = z_alpha * (z_beta(KDim - 1) + z_beta) z_q_res, w_res, _ = _scan(w, z_q, z_a, z_b, z_c) return z_q_res @@ -137,9 +136,9 @@ def _solve_nonhydro_stencil_52_like_z_q_tup( z_q: gtx.Field[[Cell, KDim], float], w: gtx.Field[[Cell, KDim], float], ) -> tuple[gtx.Field[[Cell, KDim], float]]: - z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) - z_c = z_beta * z_alpha(Koff[1]) - z_b = z_alpha * (z_beta(Koff[-1]) + z_beta) + z_a = z_beta(KDim - 1) * z_alpha(KDim - 1) + z_c = z_beta * z_alpha(KDim + 1) + z_b = z_alpha * (z_beta(KDim - 1) + z_beta) z_q_res, w_res, _ = _scan(w, z_q, z_a, z_b, z_c) return (z_q_res,) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 5c6bd5a54a..7a07299b18 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import IDim, JDim, Joff, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index dd30caa726..e5a4fb3be7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -17,6 +17,7 @@ IDim, JDim, KDim, + KHalfDim, C2E, E2V, V2E, @@ -37,7 +38,6 @@ mesh_descriptor, ) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) pytestmark = pytest.mark.uses_cartesian_shift diff --git a/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py index 9b4dae165a..a3e2c18e34 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py @@ -12,14 +12,13 @@ IDim = gtx.Dimension("IDim") -IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) @gtx.field_operator def fop( a: gtx.Field[[IDim], gtx.float64], b: gtx.Field[[IDim], gtx.float64] ) -> gtx.Field[[IDim], gtx.float64]: - return a(IOff[1]) + b + return a(IDim + 1) + b @gtx.program @@ -36,6 +35,6 @@ def test_basic(): b = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 1.0, 2.0, 3.0])) out = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 0.0, 0.0, 0.0])) - prog(a, b, out, offset_provider={"IOff": IDim}) + prog(a, b, out) assert out.domain == b.domain assert np.allclose(out.ndarray, a.ndarray + b.ndarray) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 21177d0aea..97ef301e38 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -131,9 +131,7 @@ def foo(inp: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.as_fieldop( - im.lambda_("__it")( - im.deref(im.shift(common.dimension_to_implicit_offset(TDim.value), 1)("__it")) - ) + im.lambda_("__it")(im.deref(im.shift(im.cartesian_offset(TDim, TDim), 1)("__it"))) )("inp") assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py index 0817b5f19d..69812ab1f6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -12,7 +12,6 @@ from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") -field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) def test_inline_dynamic_shift_as_fieldop_arg(uids): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d73fc1945f..c3f42761eb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -33,10 +33,8 @@ Edge, Cell, IDim, - Ioff, JDim, KDim, - Koff, V2EDim, Vertex, exec_alloc_descriptor,