diff --git a/docs/usage.rst b/docs/usage.rst index f3fa25c8..9b98428f 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -1024,6 +1024,20 @@ When writing tests, it is convenient to use :mod:`unyt.testing`. In particular, >>> desired = actual.to("cm") >>> assert_allclose_units(actual, desired) +Integrating :mod:`unyt` Into a Legacy Code with Non-Strict Mode +--------------------------------------------------------------- + +If using a custom :class:`UnitRegistry `, it +is possible to supply ``strict=False`` when initializing. This will change the behavior +when an invalid operation is attempted. Instead of raising a :class:`UnitOperationError `, +a :class:`UnitOperationWarning ` will be +provided instead, units will be stripped, and the operation will return a value +without units. + +This behavior is not recommended in general since many of the important benefits +of :mod:`unyt` are lost, but it can be useful for integrating :mod:`unyt` into +legacy code which was not previously unit-aware so that unit support can be +added gradually. Custom Unit Systems ------------------- diff --git a/unyt/array.py b/unyt/array.py index 144a2c75..81338539 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -14,10 +14,10 @@ # ----------------------------------------------------------------------------- import copy - +import re from functools import lru_cache from numbers import Number as numeric_type -import re + import numpy as np from numpy import ( add, @@ -123,6 +123,7 @@ MKSCGSConversionError, UnitOperationError, UnitConversionError, + UnitOperationWarning, UnitsNotReducible, SymbolNotFoundError, ) @@ -166,6 +167,20 @@ def _iterable(obj): return True +def _unit_operation_error_raise_or_warn(ufunc, u0, u1, func, *inputs): + if ( + hasattr(u0, "units") + and not u0.units.registry.strict + and hasattr(u1, "units") + and not u1.units.registry.strict + ): + warnings.warn(UnitOperationWarning(ufunc, u0, u1)) + unwrapped_inputs = [i.value if isinstance(i, unyt_array) else i for i in inputs] + return func(*unwrapped_inputs) + else: + raise UnitOperationError(ufunc, u0, u1) + + @lru_cache(maxsize=128, typed=False) def _sqrt_unit(unit): return 1, unit ** 0.5 @@ -1759,12 +1774,16 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): elif ufunc is power: u1 = inp1 if inp0.shape != () and inp1.shape != (): - raise UnitOperationError(ufunc, u0, u1) + return _unit_operation_error_raise_or_warn( + ufunc, u0, u1, func, *inputs + ) if isinstance(u1, unyt_array): if u1.units.is_dimensionless: pass else: - raise UnitOperationError(ufunc, u0, u1.units) + return _unit_operation_error_raise_or_warn( + ufunc, u0, u1.units, func, *inputs + ) if u1.shape == (): u1 = float(u1) else: @@ -1814,9 +1833,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ret = bool(ret) return ret else: - raise UnitOperationError(ufunc, u0, u1) + return _unit_operation_error_raise_or_warn( + ufunc, u0, u1, func, *inputs + ) else: - raise UnitOperationError(ufunc, u0, u1) + return _unit_operation_error_raise_or_warn( + ufunc, u0, u1, func, *inputs + ) conv, offset = u1.get_conversion_factor(u0, inp1.dtype) new_dtype = np.dtype("f" + str(inp1.dtype.itemsize)) conv = new_dtype.type(conv) diff --git a/unyt/exceptions.py b/unyt/exceptions.py index d1001632..2d6c3991 100644 --- a/unyt/exceptions.py +++ b/unyt/exceptions.py @@ -47,6 +47,29 @@ def __str__(self): return err +class UnitOperationWarning(UserWarning): + """A warning that is raised when unit operations are not allowed but + when running with a UnitRegistry with strict=False. In this case, + the operation is allowed to continue but with unit information stripped. + """ + + def __init__(self, operation, unit1, unit2=None): + self.operation = operation + self.unit1 = unit1 + self.unit2 = unit2 + UserWarning.__init__(self) + + def __str__(self): + err = ( + f"The {self.operation.__name__} operator for unyt_arrays" + f" with units {self.unit1!r} (dimensions {self.unit1.dimensions!r})" + ) + if self.unit2 is not None: + err += f" and {self.unit2!r} (dimensions {self.unit2.dimensions!r})" + err += " is not well defined. Performing operation without units instead." + return err + + class UnitConversionError(Exception): """An error raised when converting to a unit with different dimensions. diff --git a/unyt/tests/test_unyt_array.py b/unyt/tests/test_unyt_array.py index 7b4ec7a1..cdd724d9 100644 --- a/unyt/tests/test_unyt_array.py +++ b/unyt/tests/test_unyt_array.py @@ -17,22 +17,25 @@ import copy import itertools import math -import numpy as np import operator import os import pickle -import pytest import shutil import tempfile import warnings +import numpy as np +import pytest +from numpy import array from numpy.testing import ( assert_array_equal, assert_equal, assert_array_almost_equal, assert_almost_equal, ) -from numpy import array +from unyt import dimensions, Unit, degC, K, delta_degC, degF, R, delta_degF +from unyt._on_demand_imports import _astropy, _h5py, _pint, NotAModule +from unyt._physical_ratios import metallicity_sun, speed_of_light_cm_per_s from unyt.array import ( unyt_array, unyt_quantity, @@ -55,15 +58,13 @@ IterableUnitCoercionError, UnitConversionError, UnitOperationError, + UnitOperationWarning, UnitParseError, UnitsNotReducible, ) from unyt.testing import assert_allclose_units, _process_warning -from unyt.unit_symbols import cm, m, g, degree from unyt.unit_registry import UnitRegistry -from unyt._on_demand_imports import _astropy, _h5py, _pint, NotAModule -from unyt._physical_ratios import metallicity_sun, speed_of_light_cm_per_s -from unyt import dimensions, Unit, degC, K, delta_degC, degF, R, delta_degF +from unyt.unit_symbols import cm, m, g, degree def operate_and_compare(a, b, op, answer): @@ -2539,3 +2540,15 @@ def test_invalid_unit_quantity_from_string(s): match="Could not find unit symbol '{}' in the provided symbols.".format(un_str), ): unyt_quantity.from_string(s) + + +def test_non_strict_registry(): + + reg = UnitRegistry(strict=False) + + a1 = unyt_array([1, 2, 3], "m", registry=reg) + a2 = unyt_array([4, 5, 6], "kg", registry=reg) + + with pytest.warns(UnitOperationWarning): + answer = operator.add(a1, a2) + assert_array_equal(answer, [5, 7, 9]) diff --git a/unyt/unit_registry.py b/unyt/unit_registry.py index 50c44da1..0148bf64 100644 --- a/unyt/unit_registry.py +++ b/unyt/unit_registry.py @@ -55,7 +55,9 @@ class UnitRegistry: _unit_system_id = None - def __init__(self, add_default_symbols=True, lut=None, unit_system=None): + def __init__( + self, add_default_symbols=True, lut=None, unit_system=None, *, strict=True + ): self._unit_object_cache = {} if lut: self.lut = lut @@ -67,6 +69,12 @@ def __init__(self, add_default_symbols=True, lut=None, unit_system=None): if add_default_symbols: self.lut.update(default_unit_symbol_lut) + # This boolean determines whether to raise a UnitOperationError or + # strip units and provide a UnitOperationWarning if an invalid + # operation is attempted. The default is strict=True and is + # strongly recommended. + self.strict = strict + def __getitem__(self, key): try: ret = self.lut[str(key)]