diff --git a/astropy/units/quantity.py b/astropy/units/quantity.py index b98abfafb09c..f8d496e63009 100644 --- a/astropy/units/quantity.py +++ b/astropy/units/quantity.py @@ -636,11 +636,69 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs): result : `~astropy.units.Quantity` Results of the ufunc, with the unit set properly. """ + def _is_plain_numpy(arg): + return isinstance(arg, (np.ndarray, np.generic, numbers.Number)) + + def _is_table_column(arg): + module = getattr(type(arg), "__module__", "") + return module.startswith("astropy.table.") and hasattr(arg, "info") + + # Helper: decide whether to defer to a mixed duck-type input. + # We only defer for binary ufunc __call__ (nin == 2) with mixed + # Quantity and non-Quantity inputs, where the non-Quantity side is a duck that + # either advertises __array_ufunc__ or carries a 'unit' attribute. + def _should_defer_mixed_duck(): + if method != "__call__" or getattr(function, "nin", None) != 2: + return False + # Identify non-Quantity inputs. + has_quantity = any(isinstance(i, Quantity) for i in inputs) + has_non_quantity = any(not isinstance(i, Quantity) for i in inputs) + if not (has_quantity and has_non_quantity): + return False + # If all non-Quantity are plain numpy arrays/scalars, do not defer. + # Any foreign duck? (has __array_ufunc__ or 'unit' attribute) + for inp in inputs: + if isinstance(inp, Quantity) or _is_plain_numpy(inp): + continue + if hasattr(inp, "__array_ufunc__") or hasattr(inp, "unit"): + return True + return False + + if ( + method == "__call__" + and getattr(function, "nin", None) == 2 + and any(isinstance(arg, Quantity) for arg in inputs) + ): + for arg in inputs: + if ( + isinstance(arg, Quantity) + or _is_plain_numpy(arg) + or _is_table_column(arg) + ): + continue + try: + unit_attr = getattr(arg, "unit") + except AttributeError: + continue + except Exception: + return NotImplemented + if unit_attr is not None: + return NotImplemented + # Determine required conversion functions -- to bring the unit of the # input to that expected (e.g., radian for np.sin), or to get # consistent units between two inputs (e.g., in np.add) -- # and the unit of the result (or tuple of units for nout > 1). - converters, unit = converters_and_unit(function, method, *inputs) + try: + converters, unit = converters_and_unit(function, method, *inputs) + except (AttributeError, TypeError, ValueError, UnitConversionError) as err: + # If we are in a mixed duck-type situation for binary arithmetic, + # return NotImplemented to allow numpy to try the other operand. + if _should_defer_mixed_duck(): + return NotImplemented + if isinstance(err, UnitConversionError): + raise UnitTypeError(err.args[0]) from err + raise out = kwargs.get("out", None) # Avoid loop back by turning any Quantity output into array views. @@ -666,8 +724,17 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs): # Same for inputs, but here also convert if necessary. arrays = [] for input_, converter in zip(inputs, converters): - input_ = getattr(input_, "value", input_) - arrays.append(converter(input_) if converter else input_) + try: + input_ = getattr(input_, "value", input_) + if converter: + input_ = converter(input_) + except (AttributeError, TypeError, ValueError): + # If converter value extraction or application fails for mixed + # duck types in binary ufunc calls, defer to the other operand. + if _should_defer_mixed_duck(): + return NotImplemented + raise + arrays.append(input_) # Call our superclass's __array_ufunc__ result = super().__array_ufunc__(function, method, *arrays, **kwargs) diff --git a/astropy/units/tests/test_quantity_ufunc_defer.py b/astropy/units/tests/test_quantity_ufunc_defer.py new file mode 100644 index 000000000000..13e1ed7e1df4 --- /dev/null +++ b/astropy/units/tests/test_quantity_ufunc_defer.py @@ -0,0 +1,221 @@ +import numpy as np +import pytest + +import astropy.units as u +from astropy.table import Column +from astropy.units import Quantity +from astropy.units.core import UnitTypeError + + +class DuckArray: + """Simple duck-type that carries a Quantity internally and handles ufuncs. + + The purpose is to verify that astropy Quantity returns NotImplemented for + mixed operations it cannot interpret, allowing this duck to handle them. + """ + + def __init__(self, q): + assert isinstance(q, Quantity) + self.q = q + + @property + def unit(self): + # Expose a unit attribute to appear quantity-like + return self.q.unit + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + # Convert any Quantity inputs to their values in our unit, then apply. + if method != "__call__" or ufunc.nin != 2: + return NotImplemented + a, b = inputs + target_unit = None + if isinstance(a, Quantity): + target_unit = a.unit + a = a.to(self.unit).value + if isinstance(b, Quantity): + if target_unit is None: + target_unit = b.unit + b = b.to(self.unit).value + # Extract DuckArray payloads + if isinstance(a, DuckArray): + a = a.q.to(self.unit).value + if isinstance(b, DuckArray): + b = b.q.to(self.unit).value + res = getattr(ufunc, method)(a, b, **kwargs) + result_quantity = res * self.unit + if target_unit is not None: + result_quantity = result_quantity.to(target_unit) + return DuckArray(result_quantity) + + # enable comparisons in assertions + def __eq__(self, other): + if isinstance(other, DuckArray): + return np.allclose(other.q.value, self.q.value) and other.q.unit == self.q.unit + return NotImplemented + + +def test_mixed_duck_add_defers_to_duck(): + q = 1 * u.m + duck = DuckArray(1 * u.mm) + # Quantity should return NotImplemented and let duck handle + res = np.add(q, duck) + assert isinstance(res, DuckArray) + assert res.q.unit == u.m + assert np.allclose(res.q.value, 1.001) + + # Also for operator + res2 = q + duck + assert isinstance(res2, DuckArray) + assert res2 == res + + +def test_mixed_duck_reflected_add(): + duck = DuckArray(2 * u.mm) + q = 3 * u.m + res = np.add(duck, q) + assert isinstance(res, DuckArray) + assert res.q.unit == u.m + assert np.allclose(res.q.value, 3.002) + + res2 = duck + q + assert isinstance(res2, DuckArray) + assert res2 == res + + +def test_mixed_duck_subtract_defers_both_ways(): + duck = DuckArray(5 * u.cm) + q = 1 * u.m + + res = np.subtract(q, duck) + assert isinstance(res, DuckArray) + assert res.q.unit == u.m + assert np.allclose(res.q.value, 0.95) + + res2 = np.subtract(duck, q) + assert isinstance(res2, DuckArray) + assert res2.q.unit == u.m + assert np.allclose(res2.q.value, -0.95) + + +def test_incompatible_quantities_still_raise(): + a = 1 * u.m + b = 1 * u.s + with pytest.raises(UnitTypeError): + np.add(a, b) + + +def test_numpy_scalar_and_array_unchanged(): + q = 3 * u.m + s = 2.0 + arr = np.array([1.0, 2.0]) + # Multiplication with numpy scalars/arrays should continue to be handled + # by Quantity as before. + res_s = np.multiply(q, s) + assert isinstance(res_s, Quantity) + assert res_s.unit == u.m + + res_arr = np.multiply(q, arr) + assert isinstance(res_arr, Quantity) + assert res_arr.unit == u.m + + +def test_converter_condition_arg_valueerror_defers(): + # Build a duck that presents a 'value' that is non-numeric so that + # converter(input_) will raise ValueError in the Quantity path. Quantity + # should then return NotImplemented and the duck takes over. + class BadValueDuck: + def __init__(self, unit): + self._unit = unit + self.value = "bad" # non-numeric to trigger ValueError + + @property + def unit(self): + return self._unit + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == "__call__" and ufunc.nin == 2: + return DuckArray(5 * self.unit) + return NotImplemented + + q = 1 * u.m + duck = BadValueDuck(u.m) + res = np.add(q, duck) + assert isinstance(res, DuckArray) + assert res.q == 5 * u.m + + + +def test_converter_discovery_failure_defers(): + # Foreign duck lacks astropy-compatible unit container; getting converter + # may raise TypeError/AttributeError. Quantity should defer. + class ForeignNoConv: + __array_priority__ = 1e6 + def __init__(self): + self.unit = object() + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == "__call__" and ufunc.nin == 2: + return 'handled-by-foreign' + return NotImplemented + q = 1 * u.m + f = ForeignNoConv() + assert np.add(q, f) == 'handled-by-foreign' + + +def test_converter_application_typeerror_defers(): + # Converter may be constructed, but applying it raises TypeError. + class ForeignBadApply: + def __init__(self, unit): + self._unit = unit + class Weird: + # Behaves oddly under numpy array coercion + def __array__(self): + raise TypeError('cannot array-coerce') + self.value = Weird() + @property + def unit(self): + return self._unit + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == "__call__" and ufunc.nin == 2: + return 'handled-typeerror' + return NotImplemented + q = 1 * u.m + f = ForeignBadApply(u.m) + assert np.add(q, f) == 'handled-typeerror' + + +def test_reduce_accumulate_at_unchanged(): + # Ensure we did not alter behavior for ufunc methods other than __call__ + arr = np.array([1.0, 2.0, 3.0]) * u.m + # reduce keeps unit for add + r = np.add.reduce(arr) + assert isinstance(r, Quantity) + assert r.unit == u.m + # accumulate keeps unit and shape + acc = np.add.accumulate(arr) + assert isinstance(acc, Quantity) + assert acc.unit == u.m + assert acc.shape == arr.shape + # at modifies in place; for compatibility we only check it runs + a = arr.copy() + idx = np.array([0]) + np.add.at(a, idx, 1 * u.m) + assert a[0].unit == u.m + + +def test_quantity_column_interaction_preserved(): + column = Column([1, 2], unit=u.m) + q = 3 * u.m + + res = q + column + assert isinstance(res, Quantity) + assert np.allclose(res.value, [4, 5]) + + reflected = column + q + assert hasattr(reflected, "unit") + assert reflected.unit == u.m + assert np.allclose(reflected.value, [4, 5]) + + via_numpy = np.add(column, q) + assert hasattr(via_numpy, "unit") + assert via_numpy.unit == u.m + assert np.allclose(via_numpy.value, [4, 5]) diff --git a/docs/units/quantity.rst b/docs/units/quantity.rst index faf55130cd79..3e438f0d7cd7 100644 --- a/docs/units/quantity.rst +++ b/docs/units/quantity.rst @@ -219,6 +219,16 @@ To perform these operations on |Quantity| objects: >>> 20. * u.cm / (1. * u.m) # doctest: +FLOAT_CMP +.. note:: + Interoperability with duck types: for binary ufunc calls with mixed + operands where the other object is a non-astropy, quantity-like duck + type (e.g., an array that implements the NumPy ``__array_ufunc__`` + protocol or carries a ``unit`` attribute), and Astropy cannot interpret + that input, ``Quantity.__array_ufunc__`` will return ``NotImplemented``. + This allows NumPy to dispatch to the other operand’s implementation or its + reflected operation. Methods such as reduce/accumulate/at are unchanged. Behavior for pure Astropy ``Quantity`` operands and + for NumPy scalars/ndarrays is unchanged. + For multiplication, you can change how to represent the resulting object by using the :meth:`~astropy.units.quantity.Quantity.to` method: