Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions astropy/units/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,69 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs):
result : `~astropy.units.Quantity`
Results of the ufunc, with the unit set properly.
"""
def _is_plain_numpy(arg):
return isinstance(arg, (np.ndarray, np.generic, numbers.Number))

def _is_table_column(arg):
module = getattr(type(arg), "__module__", "")
return module.startswith("astropy.table.") and hasattr(arg, "info")

# Helper: decide whether to defer to a mixed duck-type input.
# We only defer for binary ufunc __call__ (nin == 2) with mixed
# Quantity and non-Quantity inputs, where the non-Quantity side is a duck that
# either advertises __array_ufunc__ or carries a 'unit' attribute.
def _should_defer_mixed_duck():
if method != "__call__" or getattr(function, "nin", None) != 2:
return False
# Identify non-Quantity inputs.
has_quantity = any(isinstance(i, Quantity) for i in inputs)
has_non_quantity = any(not isinstance(i, Quantity) for i in inputs)
if not (has_quantity and has_non_quantity):
return False
# If all non-Quantity are plain numpy arrays/scalars, do not defer.
# Any foreign duck? (has __array_ufunc__ or 'unit' attribute)
for inp in inputs:
if isinstance(inp, Quantity) or _is_plain_numpy(inp):
continue
if hasattr(inp, "__array_ufunc__") or hasattr(inp, "unit"):
return True
return False

if (
method == "__call__"
and getattr(function, "nin", None) == 2
and any(isinstance(arg, Quantity) for arg in inputs)
):
for arg in inputs:
if (
isinstance(arg, Quantity)
or _is_plain_numpy(arg)
or _is_table_column(arg)
):
continue
try:
unit_attr = getattr(arg, "unit")
except AttributeError:
continue
except Exception:
return NotImplemented
if unit_attr is not None:
return NotImplemented

# Determine required conversion functions -- to bring the unit of the
# input to that expected (e.g., radian for np.sin), or to get
# consistent units between two inputs (e.g., in np.add) --
# and the unit of the result (or tuple of units for nout > 1).
converters, unit = converters_and_unit(function, method, *inputs)
try:
converters, unit = converters_and_unit(function, method, *inputs)
except (AttributeError, TypeError, ValueError, UnitConversionError) as err:
# If we are in a mixed duck-type situation for binary arithmetic,
# return NotImplemented to allow numpy to try the other operand.
if _should_defer_mixed_duck():
return NotImplemented
if isinstance(err, UnitConversionError):
raise UnitTypeError(err.args[0]) from err
raise

out = kwargs.get("out", None)
# Avoid loop back by turning any Quantity output into array views.
Expand All @@ -666,8 +724,17 @@ def __array_ufunc__(self, function, method, *inputs, **kwargs):
# Same for inputs, but here also convert if necessary.
arrays = []
for input_, converter in zip(inputs, converters):
input_ = getattr(input_, "value", input_)
arrays.append(converter(input_) if converter else input_)
try:
input_ = getattr(input_, "value", input_)
if converter:
input_ = converter(input_)
except (AttributeError, TypeError, ValueError):
# If converter value extraction or application fails for mixed
# duck types in binary ufunc calls, defer to the other operand.
if _should_defer_mixed_duck():
return NotImplemented
raise
arrays.append(input_)

# Call our superclass's __array_ufunc__
result = super().__array_ufunc__(function, method, *arrays, **kwargs)
Expand Down
221 changes: 221 additions & 0 deletions astropy/units/tests/test_quantity_ufunc_defer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import numpy as np
import pytest

import astropy.units as u
from astropy.table import Column
from astropy.units import Quantity
from astropy.units.core import UnitTypeError


class DuckArray:
"""Simple duck-type that carries a Quantity internally and handles ufuncs.

The purpose is to verify that astropy Quantity returns NotImplemented for
mixed operations it cannot interpret, allowing this duck to handle them.
"""

def __init__(self, q):
assert isinstance(q, Quantity)
self.q = q

@property
def unit(self):
# Expose a unit attribute to appear quantity-like
return self.q.unit

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# Convert any Quantity inputs to their values in our unit, then apply.
if method != "__call__" or ufunc.nin != 2:
return NotImplemented
a, b = inputs
target_unit = None
if isinstance(a, Quantity):
target_unit = a.unit
a = a.to(self.unit).value
if isinstance(b, Quantity):
if target_unit is None:
target_unit = b.unit
b = b.to(self.unit).value
# Extract DuckArray payloads
if isinstance(a, DuckArray):
a = a.q.to(self.unit).value
if isinstance(b, DuckArray):
b = b.q.to(self.unit).value
res = getattr(ufunc, method)(a, b, **kwargs)
result_quantity = res * self.unit
if target_unit is not None:
result_quantity = result_quantity.to(target_unit)
return DuckArray(result_quantity)

# enable comparisons in assertions
def __eq__(self, other):
if isinstance(other, DuckArray):
return np.allclose(other.q.value, self.q.value) and other.q.unit == self.q.unit
return NotImplemented


def test_mixed_duck_add_defers_to_duck():
q = 1 * u.m
duck = DuckArray(1 * u.mm)
# Quantity should return NotImplemented and let duck handle
res = np.add(q, duck)
assert isinstance(res, DuckArray)
assert res.q.unit == u.m
assert np.allclose(res.q.value, 1.001)

# Also for operator
res2 = q + duck
assert isinstance(res2, DuckArray)
assert res2 == res


def test_mixed_duck_reflected_add():
duck = DuckArray(2 * u.mm)
q = 3 * u.m
res = np.add(duck, q)
assert isinstance(res, DuckArray)
assert res.q.unit == u.m
assert np.allclose(res.q.value, 3.002)

res2 = duck + q
assert isinstance(res2, DuckArray)
assert res2 == res


def test_mixed_duck_subtract_defers_both_ways():
duck = DuckArray(5 * u.cm)
q = 1 * u.m

res = np.subtract(q, duck)
assert isinstance(res, DuckArray)
assert res.q.unit == u.m
assert np.allclose(res.q.value, 0.95)

res2 = np.subtract(duck, q)
assert isinstance(res2, DuckArray)
assert res2.q.unit == u.m
assert np.allclose(res2.q.value, -0.95)


def test_incompatible_quantities_still_raise():
a = 1 * u.m
b = 1 * u.s
with pytest.raises(UnitTypeError):
np.add(a, b)


def test_numpy_scalar_and_array_unchanged():
q = 3 * u.m
s = 2.0
arr = np.array([1.0, 2.0])
# Multiplication with numpy scalars/arrays should continue to be handled
# by Quantity as before.
res_s = np.multiply(q, s)
assert isinstance(res_s, Quantity)
assert res_s.unit == u.m

res_arr = np.multiply(q, arr)
assert isinstance(res_arr, Quantity)
assert res_arr.unit == u.m


def test_converter_condition_arg_valueerror_defers():
# Build a duck that presents a 'value' that is non-numeric so that
# converter(input_) will raise ValueError in the Quantity path. Quantity
# should then return NotImplemented and the duck takes over.
class BadValueDuck:
def __init__(self, unit):
self._unit = unit
self.value = "bad" # non-numeric to trigger ValueError

@property
def unit(self):
return self._unit

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == "__call__" and ufunc.nin == 2:
return DuckArray(5 * self.unit)
return NotImplemented

q = 1 * u.m
duck = BadValueDuck(u.m)
res = np.add(q, duck)
assert isinstance(res, DuckArray)
assert res.q == 5 * u.m



def test_converter_discovery_failure_defers():
# Foreign duck lacks astropy-compatible unit container; getting converter
# may raise TypeError/AttributeError. Quantity should defer.
class ForeignNoConv:
__array_priority__ = 1e6
def __init__(self):
self.unit = object()
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == "__call__" and ufunc.nin == 2:
return 'handled-by-foreign'
return NotImplemented
q = 1 * u.m
f = ForeignNoConv()
assert np.add(q, f) == 'handled-by-foreign'


def test_converter_application_typeerror_defers():
# Converter may be constructed, but applying it raises TypeError.
class ForeignBadApply:
def __init__(self, unit):
self._unit = unit
class Weird:
# Behaves oddly under numpy array coercion
def __array__(self):
raise TypeError('cannot array-coerce')
self.value = Weird()
@property
def unit(self):
return self._unit
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == "__call__" and ufunc.nin == 2:
return 'handled-typeerror'
return NotImplemented
q = 1 * u.m
f = ForeignBadApply(u.m)
assert np.add(q, f) == 'handled-typeerror'


def test_reduce_accumulate_at_unchanged():
# Ensure we did not alter behavior for ufunc methods other than __call__
arr = np.array([1.0, 2.0, 3.0]) * u.m
# reduce keeps unit for add
r = np.add.reduce(arr)
assert isinstance(r, Quantity)
assert r.unit == u.m
# accumulate keeps unit and shape
acc = np.add.accumulate(arr)
assert isinstance(acc, Quantity)
assert acc.unit == u.m
assert acc.shape == arr.shape
# at modifies in place; for compatibility we only check it runs
a = arr.copy()
idx = np.array([0])
np.add.at(a, idx, 1 * u.m)
assert a[0].unit == u.m


def test_quantity_column_interaction_preserved():
column = Column([1, 2], unit=u.m)
q = 3 * u.m

res = q + column
assert isinstance(res, Quantity)
assert np.allclose(res.value, [4, 5])

reflected = column + q
assert hasattr(reflected, "unit")
assert reflected.unit == u.m
assert np.allclose(reflected.value, [4, 5])

via_numpy = np.add(column, q)
assert hasattr(via_numpy, "unit")
assert via_numpy.unit == u.m
assert np.allclose(via_numpy.value, [4, 5])
10 changes: 10 additions & 0 deletions docs/units/quantity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ To perform these operations on |Quantity| objects:
>>> 20. * u.cm / (1. * u.m) # doctest: +FLOAT_CMP
<Quantity 20. cm / m>

.. note::
Interoperability with duck types: for binary ufunc calls with mixed
operands where the other object is a non-astropy, quantity-like duck
type (e.g., an array that implements the NumPy ``__array_ufunc__``
protocol or carries a ``unit`` attribute), and Astropy cannot interpret
that input, ``Quantity.__array_ufunc__`` will return ``NotImplemented``.
This allows NumPy to dispatch to the other operand’s implementation or its
reflected operation. Methods such as reduce/accumulate/at are unchanged. Behavior for pure Astropy ``Quantity`` operands and
for NumPy scalars/ndarrays is unchanged.

For multiplication, you can change how to represent the resulting object by
using the :meth:`~astropy.units.quantity.Quantity.to` method:

Expand Down