Skip to content

Commit

Permalink
Merge pull request #164 from neutrinoceros/bugfix_yt_GH_874_2
Browse files Browse the repository at this point in the history
bugfix: fix commutativity in unyt_array operators
  • Loading branch information
ngoldbaum authored Aug 1, 2020
2 parents e957e4c + 9b72505 commit 14e4cdb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
14 changes: 12 additions & 2 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
UnitOperationError,
UnitConversionError,
UnitsNotReducible,
SymbolNotFoundError,
)
from unyt.equivalencies import equivalence_registry
from unyt._on_demand_imports import _astropy, _pint
Expand Down Expand Up @@ -160,7 +161,13 @@ def _sqrt_unit(unit):

@lru_cache(maxsize=128, typed=False)
def _multiply_units(unit1, unit2):
ret = (unit1 * unit2).simplify()
try:
ret = (unit1 * unit2).simplify()
except SymbolNotFoundError:
# Some operators are not natively commutative when operands are
# defined within different unit registries, and conversion
# is defined one way but not the other.
ret = (unit2 * unit1).simplify()
return ret.as_coeff_unit()


Expand Down Expand Up @@ -195,7 +202,10 @@ def _square_unit(unit):

@lru_cache(maxsize=128, typed=False)
def _divide_units(unit1, unit2):
ret = (unit1 / unit2).simplify()
try:
ret = (unit1 / unit2).simplify()
except SymbolNotFoundError:
ret = (1 / (unit2 / unit1).simplify()).units
return ret.as_coeff_unit()


Expand Down
15 changes: 15 additions & 0 deletions unyt/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytest
from sympy import Symbol

from unyt.array import unyt_quantity
from unyt.testing import assert_allclose_units
from unyt.unit_registry import UnitRegistry
from unyt.dimensions import (
Expand Down Expand Up @@ -874,3 +875,17 @@ def test_degF():
def test_delta_degF():
a = 1 * Unit("delta_degF")
assert str(a) == "1 Δ°F"


def test_mixed_registry_operations():

reg = UnitRegistry(unit_system="cgs")
reg.add("fake_length", 0.001, length)
a = unyt_quantity(1, units="fake_length", registry=reg)
b = unyt_quantity(1, "cm")

assert_almost_equal(a + b, b + a)
assert_almost_equal(a - b, -(b - a))
assert_almost_equal(a * b, b * a)
assert_almost_equal(b / a, b / a.in_units("km"))
assert_almost_equal(a / b, a / b.in_units("km"))

0 comments on commit 14e4cdb

Please sign in to comment.