From 104b2cdc144017a224cfda71a7f75bb74ba3aba2 Mon Sep 17 00:00:00 2001 From: Andrew Date: Mon, 13 May 2024 23:49:19 +0100 Subject: [PATCH] td64 dtypes that aren't second --- pint/compat.py | 28 +++++++++++---- pint/facets/plain/quantity.py | 21 ++++++------ pint/testsuite/test_quantity.py | 60 +++++++++++++++++++++++++++------ 3 files changed, 82 insertions(+), 27 deletions(-) diff --git a/pint/compat.py b/pint/compat.py index 612d6be4c..783fa236c 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -20,6 +20,7 @@ from typing import ( Any, NoReturn, + Tuple, TypeAlias, # noqa ) @@ -305,16 +306,31 @@ def is_timedelta(obj: Any) -> bool: def is_timedelta_array(obj: Any) -> bool: """Check if the object is a datetime array.""" - if isinstance(obj, ndarray) and obj.dtype == np_timedelta64: + if isinstance(obj, ndarray) and obj.dtype.type == np_timedelta64: return True -def to_seconds(obj: Any) -> float: - """Convert a timedelta object to seconds.""" +def convert_timedelta(obj: Any) -> Tuple[float, str]: + """Convert a timedelta object to magnitude and unit string.""" + _dtype_to_unit = { + "timedelta64[Y]": "year", + "timedelta64[M]": "month", + "timedelta64[W]": "week", + "timedelta64[D]": "day", + "timedelta64[h]": "hour", + "timedelta64[m]": "minute", + "timedelta64[s]": "s", + "timedelta64[ms]": "ms", + "timedelta64[us]": "us", + "timedelta64[ns]": "ns", + "timedelta64[ps]": "ps", + "timedelta64[fs]": "fs", + "timedelta64[as]": "as", + } if isinstance(obj, datetime.timedelta): - return obj.total_seconds() - elif isinstance(obj, np_timedelta64) or obj.dtype == np_timedelta64: - return obj.astype(float) + return obj.total_seconds(), "s" + elif isinstance(obj, np_timedelta64) or obj.dtype.type == np_timedelta64: + return obj.astype(float), _dtype_to_unit[str(obj.dtype)] raise TypeError(f"Cannot convert {obj!r} to seconds.") diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py index 638be3974..eabc59662 100644 --- a/pint/facets/plain/quantity.py +++ b/pint/facets/plain/quantity.py @@ -27,6 +27,7 @@ from ...compat import ( HAS_NUMPY, _to_magnitude, + convert_timedelta, deprecated, eq, is_duck_array_type, @@ -34,7 +35,6 @@ is_timedelta_array, is_upcast_type, np, - to_seconds, zero_or_nan, ) from ...errors import DimensionalityError, OffsetUnitCalculusError, PintTypeError @@ -204,11 +204,17 @@ def __new__(cls, value, units=None): if units is None and isinstance(value, cls): return copy.copy(value) - inst = SharedRegistryObject().__new__(cls) - if units is None and (is_timedelta(value) or is_timedelta_array(value)): - units = inst.UnitsContainer({"s": 1}) - elif units is None: + + if is_timedelta(value) or is_timedelta_array(value): + m, u = convert_timedelta(value) + inst._magnitude = m + inst._units = inst.UnitsContainer({u: 1}) + if units: + inst.ito(units) + return inst + + if units is None: units = inst.UnitsContainer() else: if isinstance(units, (UnitsContainer, UnitDefinition)): @@ -228,11 +234,6 @@ def __new__(cls, value, units=None): "units must be of type str, PlainQuantity or " "UnitsContainer; not {}.".format(type(units)) ) - if is_timedelta(value) or is_timedelta_array(value): - inst._magnitude = to_seconds(value) - inst._units = inst.UnitsContainer({"s": 1}) - return inst.to(units) - if isinstance(value, cls): magnitude = value.to(units)._magnitude else: diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index 6bc9ebb1e..3f9af463b 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -1892,19 +1892,57 @@ def test_init_quantity(self): assert q_hours == 3 * self.ureg.second assert q_hours.units == self.ureg.hour + @pytest.mark.parametrize( + ["timedelta_unit", "pint_unit"], + ( + pytest.param("s", "second", id="second"), + pytest.param("ms", "millisecond", id="millisecond"), + pytest.param("us", "microsecond", id="microsecond"), + pytest.param("ns", "nanosecond", id="nanosecond"), + pytest.param("m", "minute", id="minute"), + pytest.param("h", "hour", id="hour"), + pytest.param("D", "day", id="day"), + pytest.param("W", "week", id="week"), + pytest.param("M", "month", id="month"), + pytest.param("Y", "year", id="year"), + ), + ) @helpers.requires_numpy - def test_init_quantity_np(self): - td = np.timedelta64(3, "s") - assert self.Q_(td) == 3 * self.ureg.second - q_hours = self.Q_(td, "hours") - assert q_hours == 3 * self.ureg.second - assert q_hours.units == self.ureg.hour + def test_init_quantity_np(self, timedelta_unit, pint_unit): + # test init with the timedelta unit + td = np.timedelta64(3, timedelta_unit) + result = self.Q_(td) + expected = self.Q_(3, pint_unit) + helpers.assert_quantity_almost_equal(result, expected) + # check units are same. Use Q_ since Unit(s) != Unit(second) + helpers.assert_quantity_almost_equal( + self.Q_(1, result.units), self.Q_(1, expected.units) + ) - td = np.array([3], dtype="timedelta64") - assert self.Q_(td) == np.array([3]) * self.ureg.second - q_hours = self.Q_(td, "hours") - assert q_hours == np.array([3]) * self.ureg.second - assert q_hours.units == self.ureg.hour + # test init with unit specified + result = self.Q_(td, "hours") + expected = self.Q_(3, pint_unit).to("hours") + helpers.assert_quantity_almost_equal(result, expected) + helpers.assert_quantity_almost_equal( + self.Q_(1, result.units), self.Q_(1, expected.units) + ) + + # test array + td = np.array([3], dtype="timedelta64[{}]".format(timedelta_unit)) + result = self.Q_(td) + expected = self.Q_([3], pint_unit) + helpers.assert_quantity_almost_equal(result, expected) + helpers.assert_quantity_almost_equal( + self.Q_(1, result.units), self.Q_(1, expected.units) + ) + + # test array with unit specified + result = self.Q_(td, "hours") + expected = self.Q_([3], pint_unit).to("hours") + helpers.assert_quantity_almost_equal(result, expected) + helpers.assert_quantity_almost_equal( + self.Q_(1, result.units), self.Q_(1, expected.units) + ) # TODO: do not subclass from QuantityTestCase