Skip to content

Commit 6f90b7c

Browse files
committed
Add an arithmetic_compat option to xr.set_options, which determines how non-index coordinates of the same name are compared for potential conflicts when performing binary operations.
The default of compat='minimal' matches the previous behaviour.
1 parent 121f266 commit 6f90b7c

File tree

8 files changed

+135
-16
lines changed

8 files changed

+135
-16
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ New Features
2222
- :py:func:`combine_nested` now support :py:class:`DataTree` objects
2323
(:pull:`10849`).
2424
By `Stephan Hoyer <https://github.com/shoyer>`_.
25+
- :py:func:`set_options` now supports an ``arithmetic_compat`` option which determines how non-index coordinates
26+
of the same name are compared for potential conflicts when performing binary operations. The default for it is
27+
``arithmetic_compat='minimal'`` which matches the existing behaviour.
28+
By `Matthew Willson <https://github.com/mjwillson>`_.
2529

2630
Breaking Changes
2731
~~~~~~~~~~~~~~~~

xarray/core/coordinates.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121
assert_no_index_corrupted,
2222
create_default_index_implicit,
2323
)
24-
from xarray.core.types import DataVars, ErrorOptions, Self, T_DataArray, T_Xarray
24+
from xarray.core.types import (
25+
CompatOptions,
26+
DataVars,
27+
ErrorOptions,
28+
Self,
29+
T_DataArray,
30+
T_Xarray,
31+
)
2532
from xarray.core.utils import (
2633
Frozen,
2734
ReprObject,
@@ -31,6 +38,7 @@
3138
from xarray.core.variable import Variable, as_variable, calculate_dimensions
3239
from xarray.structure.alignment import Aligner
3340
from xarray.structure.merge import merge_coordinates_without_align, merge_coords
41+
from xarray.util.deprecation_helpers import CombineKwargDefault
3442

3543
if TYPE_CHECKING:
3644
from xarray.core.common import DataWithCoords
@@ -499,18 +507,20 @@ def _drop_coords(self, coord_names):
499507
# redirect to DatasetCoordinates._drop_coords
500508
self._data.coords._drop_coords(coord_names)
501509

502-
def _merge_raw(self, other, reflexive):
510+
def _merge_raw(self, other, reflexive, compat: CompatOptions | CombineKwargDefault):
503511
"""For use with binary arithmetic."""
504512
if other is None:
505513
variables = dict(self.variables)
506514
indexes = dict(self.xindexes)
507515
else:
508516
coord_list = [self, other] if not reflexive else [other, self]
509-
variables, indexes = merge_coordinates_without_align(coord_list)
517+
variables, indexes = merge_coordinates_without_align(
518+
coord_list, compat=compat
519+
)
510520
return variables, indexes
511521

512522
@contextmanager
513-
def _merge_inplace(self, other):
523+
def _merge_inplace(self, other, compat: CompatOptions | CombineKwargDefault):
514524
"""For use with in-place binary arithmetic."""
515525
if other is None:
516526
yield
@@ -523,12 +533,16 @@ def _merge_inplace(self, other):
523533
if k not in self.xindexes
524534
}
525535
variables, indexes = merge_coordinates_without_align(
526-
[self, other], prioritized
536+
[self, other], prioritized, compat=compat
527537
)
528538
yield
529539
self._update_coords(variables, indexes)
530540

531-
def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
541+
def merge(
542+
self,
543+
other: Mapping[Any, Any] | None,
544+
compat: CompatOptions | CombineKwargDefault = "minimal",
545+
) -> Dataset:
532546
"""Merge two sets of coordinates to create a new Dataset
533547
534548
The method implements the logic used for joining coordinates in the
@@ -545,6 +559,8 @@ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
545559
other : dict-like, optional
546560
A :py:class:`Coordinates` object or any mapping that can be turned
547561
into coordinates.
562+
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"}, default: "minimal"
563+
Compatibility checks to use between coordinate variables.
548564
549565
Returns
550566
-------
@@ -559,7 +575,7 @@ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
559575
if not isinstance(other, Coordinates):
560576
other = Dataset(coords=other).coords
561577

562-
coords, indexes = merge_coordinates_without_align([self, other])
578+
coords, indexes = merge_coordinates_without_align([self, other], compat=compat)
563579
coord_names = set(coords)
564580
return Dataset._construct_direct(
565581
variables=coords, coord_names=coord_names, indexes=indexes

xarray/core/dataarray.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4899,7 +4899,9 @@ def _binary_op(
48994899
if not reflexive
49004900
else f(other_variable_or_arraylike, self.variable)
49014901
)
4902-
coords, indexes = self.coords._merge_raw(other_coords, reflexive)
4902+
coords, indexes = self.coords._merge_raw(
4903+
other_coords, reflexive, compat=OPTIONS["arithmetic_compat"]
4904+
)
49034905
name = result_name([self, other])
49044906

49054907
return self._replace(variable, coords, name, indexes=indexes)
@@ -4919,7 +4921,9 @@ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
49194921
other_coords = getattr(other, "coords", None)
49204922
other_variable = getattr(other, "variable", other)
49214923
try:
4922-
with self.coords._merge_inplace(other_coords):
4924+
with self.coords._merge_inplace(
4925+
other_coords, compat=OPTIONS["arithmetic_compat"]
4926+
):
49234927
f(self.variable, other_variable)
49244928
except MergeError as exc:
49254929
raise MergeError(

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7765,7 +7765,7 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
77657765
return type(self)(new_data_vars)
77667766

77677767
other_coords: Coordinates | None = getattr(other, "coords", None)
7768-
ds = self.coords.merge(other_coords)
7768+
ds = self.coords.merge(other_coords, compat=OPTIONS["arithmetic_compat"])
77697769

77707770
if isinstance(other, Dataset):
77717771
new_vars = apply_over_both(

xarray/core/options.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
import warnings
44
from collections.abc import Sequence
5-
from typing import TYPE_CHECKING, Any, Literal, TypedDict
5+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, get_args
66

7+
from xarray.core.types import CompatOptions
78
from xarray.core.utils import FrozenDict
89

910
if TYPE_CHECKING:
1011
from matplotlib.colors import Colormap
1112

1213
Options = Literal[
14+
"arithmetic_compat",
1315
"arithmetic_join",
1416
"chunk_manager",
1517
"cmap_divergent",
@@ -40,6 +42,7 @@
4042

4143
class T_Options(TypedDict):
4244
arithmetic_broadcast: bool
45+
arithmetic_compat: CompatOptions
4346
arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
4447
chunk_manager: str
4548
cmap_divergent: str | Colormap
@@ -70,6 +73,7 @@ class T_Options(TypedDict):
7073

7174
OPTIONS: T_Options = {
7275
"arithmetic_broadcast": True,
76+
"arithmetic_compat": "minimal",
7377
"arithmetic_join": "inner",
7478
"chunk_manager": "dask",
7579
"cmap_divergent": "RdBu_r",
@@ -109,6 +113,7 @@ def _positive_integer(value: Any) -> bool:
109113

110114
_VALIDATORS = {
111115
"arithmetic_broadcast": lambda value: isinstance(value, bool),
116+
"arithmetic_compat": get_args(CompatOptions).__contains__,
112117
"arithmetic_join": _JOIN_OPTIONS.__contains__,
113118
"display_max_children": _positive_integer,
114119
"display_max_rows": _positive_integer,
@@ -178,18 +183,34 @@ class set_options:
178183
179184
Parameters
180185
----------
186+
arithmetic_broadcast: bool, default: True
187+
Whether to perform automatic broadcasting in binary operations.
188+
arithmetic_compat: {"identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"}, default: "minimal"
189+
How to compare non-index coordinates of the same name for potential
190+
conflicts when performing binary operations. (For the alignment of index
191+
coordinates in binary operations, see `arithmetic_join`.)
192+
193+
- "identical": all values, dimensions and attributes of the coordinates
194+
must be the same.
195+
- "equals": all values and dimensions of the coordinates must be the
196+
same.
197+
- "broadcast_equals": all values of the coordinates must be equal after
198+
broadcasting to ensure common dimensions.
199+
- "no_conflicts": only values which are not null in both coordinates
200+
must be equal. The returned coordinate then contains the combination
201+
of all non-null values.
202+
- "override": skip comparing and take the coordinates from the first
203+
operand.
204+
- "minimal": drop conflicting coordinates.
181205
arithmetic_join : {"inner", "outer", "left", "right", "exact"}, default: "inner"
182-
DataArray/Dataset alignment in binary operations:
206+
DataArray/Dataset index alignment in binary operations:
183207
184208
- "outer": use the union of object indexes
185209
- "inner": use the intersection of object indexes
186210
- "left": use indexes from the first object with each dimension
187211
- "right": use indexes from the last object with each dimension
188212
- "exact": instead of aligning, raise `ValueError` when indexes to be
189213
aligned are not equal
190-
- "override": if indexes are of same size, rewrite indexes to be
191-
those of the first object with that dimension. Indexes for the same
192-
dimension must have the same size in all objects.
193214
chunk_manager : str, default: "dask"
194215
Chunk manager to use for chunked array computations when multiple
195216
options are installed.

xarray/structure/merge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def merge_coordinates_without_align(
433433
prioritized: Mapping[Any, MergeElement] | None = None,
434434
exclude_dims: AbstractSet = frozenset(),
435435
combine_attrs: CombineAttrsOptions = "override",
436+
compat: CompatOptions | CombineKwargDefault = "minimal",
436437
) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
437438
"""Merge variables/indexes from coordinates without automatic alignments.
438439
@@ -457,7 +458,7 @@ def merge_coordinates_without_align(
457458
# TODO: indexes should probably be filtered in collected elements
458459
# before merging them
459460
merged_coords, merged_indexes = merge_collected(
460-
filtered, prioritized, combine_attrs=combine_attrs
461+
filtered, prioritized, compat=compat, combine_attrs=combine_attrs
461462
)
462463
merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords))
463464

xarray/tests/test_dataarray.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DataArray,
2626
Dataset,
2727
IndexVariable,
28+
MergeError,
2829
Variable,
2930
align,
3031
broadcast,
@@ -2516,6 +2517,43 @@ def test_math_with_coords(self) -> None:
25162517
actual = alt + orig
25172518
assert_identical(expected, actual)
25182519

2520+
def test_math_with_arithmetic_compat_options(self) -> None:
2521+
# Setting up a clash of non-index coordinate 'foo':
2522+
a = xr.DataArray(
2523+
data=[0, 0, 0],
2524+
dims=["x"],
2525+
coords={
2526+
"x": [1, 2, 3],
2527+
"foo": (["x"], [1.0, 2.0, np.nan]),
2528+
},
2529+
)
2530+
b = xr.DataArray(
2531+
data=[0, 0, 0],
2532+
dims=["x"],
2533+
coords={
2534+
"x": [1, 2, 3],
2535+
"foo": (["x"], [np.nan, 2.0, 3.0]),
2536+
},
2537+
)
2538+
2539+
with xr.set_options(arithmetic_compat="minimal"):
2540+
assert_equal(a + b, a.drop_vars("foo"))
2541+
2542+
with xr.set_options(arithmetic_compat="override"):
2543+
assert_equal(a + b, a)
2544+
assert_equal(b + a, b)
2545+
2546+
with xr.set_options(arithmetic_compat="no_conflicts"):
2547+
expected = a.assign_coords(foo=(["x"], [1.0, 2.0, 3.0]))
2548+
assert_equal(a + b, expected)
2549+
assert_equal(b + a, expected)
2550+
2551+
with xr.set_options(arithmetic_compat="equals"):
2552+
with pytest.raises(MergeError):
2553+
a + b
2554+
with pytest.raises(MergeError):
2555+
b + a
2556+
25192557
def test_index_math(self) -> None:
25202558
orig = DataArray(range(3), dims="x", name="x")
25212559
actual = orig + 1

xarray/tests/test_dataset.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6834,6 +6834,41 @@ def test_binary_op_join_setting(self) -> None:
68346834
actual = ds1 + ds2
68356835
assert_equal(actual, expected)
68366836

6837+
def test_binary_op_compat_setting(self) -> None:
6838+
# Setting up a clash of non-index coordinate 'foo':
6839+
a = xr.Dataset(
6840+
data_vars={"var": (["x"], [0, 0, 0])},
6841+
coords={
6842+
"x": [1, 2, 3],
6843+
"foo": (["x"], [1.0, 2.0, np.nan]),
6844+
},
6845+
)
6846+
b = xr.Dataset(
6847+
data_vars={"var": (["x"], [0, 0, 0])},
6848+
coords={
6849+
"x": [1, 2, 3],
6850+
"foo": (["x"], [np.nan, 2.0, 3.0]),
6851+
},
6852+
)
6853+
6854+
with xr.set_options(arithmetic_compat="minimal"):
6855+
assert_equal(a + b, a.drop_vars("foo"))
6856+
6857+
with xr.set_options(arithmetic_compat="override"):
6858+
assert_equal(a + b, a)
6859+
assert_equal(b + a, b)
6860+
6861+
with xr.set_options(arithmetic_compat="no_conflicts"):
6862+
expected = a.assign_coords(foo=(["x"], [1.0, 2.0, 3.0]))
6863+
assert_equal(a + b, expected)
6864+
assert_equal(b + a, expected)
6865+
6866+
with xr.set_options(arithmetic_compat="equals"):
6867+
with pytest.raises(MergeError):
6868+
a + b
6869+
with pytest.raises(MergeError):
6870+
b + a
6871+
68376872
@pytest.mark.parametrize(
68386873
["keep_attrs", "expected"],
68396874
(

0 commit comments

Comments
 (0)