Skip to content

Commit

Permalink
td64 dtypes that aren't second
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgsavage committed May 13, 2024
1 parent 33cccf1 commit 104b2cd
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 27 deletions.
28 changes: 22 additions & 6 deletions pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import (
Any,
NoReturn,
Tuple,
TypeAlias, # noqa
)

Expand Down Expand Up @@ -305,16 +306,31 @@ def is_timedelta(obj: Any) -> bool:

def is_timedelta_array(obj: Any) -> bool:
"""Check if the object is a datetime array."""
if isinstance(obj, ndarray) and obj.dtype == np_timedelta64:
if isinstance(obj, ndarray) and obj.dtype.type == np_timedelta64:
return True


def to_seconds(obj: Any) -> float:
"""Convert a timedelta object to seconds."""
def convert_timedelta(obj: Any) -> Tuple[float, str]:
"""Convert a timedelta object to magnitude and unit string."""
_dtype_to_unit = {
"timedelta64[Y]": "year",
"timedelta64[M]": "month",
"timedelta64[W]": "week",
"timedelta64[D]": "day",
"timedelta64[h]": "hour",
"timedelta64[m]": "minute",
"timedelta64[s]": "s",
"timedelta64[ms]": "ms",
"timedelta64[us]": "us",
"timedelta64[ns]": "ns",
"timedelta64[ps]": "ps",
"timedelta64[fs]": "fs",
"timedelta64[as]": "as",
}
if isinstance(obj, datetime.timedelta):
return obj.total_seconds()
elif isinstance(obj, np_timedelta64) or obj.dtype == np_timedelta64:
return obj.astype(float)
return obj.total_seconds(), "s"
elif isinstance(obj, np_timedelta64) or obj.dtype.type == np_timedelta64:
return obj.astype(float), _dtype_to_unit[str(obj.dtype)]
raise TypeError(f"Cannot convert {obj!r} to seconds.")


Expand Down
21 changes: 11 additions & 10 deletions pint/facets/plain/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
from ...compat import (
HAS_NUMPY,
_to_magnitude,
convert_timedelta,
deprecated,
eq,
is_duck_array_type,
is_timedelta,
is_timedelta_array,
is_upcast_type,
np,
to_seconds,
zero_or_nan,
)
from ...errors import DimensionalityError, OffsetUnitCalculusError, PintTypeError
Expand Down Expand Up @@ -204,11 +204,17 @@ def __new__(cls, value, units=None):

if units is None and isinstance(value, cls):
return copy.copy(value)

inst = SharedRegistryObject().__new__(cls)
if units is None and (is_timedelta(value) or is_timedelta_array(value)):
units = inst.UnitsContainer({"s": 1})
elif units is None:

if is_timedelta(value) or is_timedelta_array(value):
m, u = convert_timedelta(value)
inst._magnitude = m
inst._units = inst.UnitsContainer({u: 1})
if units:
inst.ito(units)
return inst

if units is None:
units = inst.UnitsContainer()
else:
if isinstance(units, (UnitsContainer, UnitDefinition)):
Expand All @@ -228,11 +234,6 @@ def __new__(cls, value, units=None):
"units must be of type str, PlainQuantity or "
"UnitsContainer; not {}.".format(type(units))
)
if is_timedelta(value) or is_timedelta_array(value):
inst._magnitude = to_seconds(value)
inst._units = inst.UnitsContainer({"s": 1})
return inst.to(units)

if isinstance(value, cls):
magnitude = value.to(units)._magnitude
else:
Expand Down
60 changes: 49 additions & 11 deletions pint/testsuite/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,19 +1892,57 @@ def test_init_quantity(self):
assert q_hours == 3 * self.ureg.second
assert q_hours.units == self.ureg.hour

@pytest.mark.parametrize(
["timedelta_unit", "pint_unit"],
(
pytest.param("s", "second", id="second"),
pytest.param("ms", "millisecond", id="millisecond"),
pytest.param("us", "microsecond", id="microsecond"),
pytest.param("ns", "nanosecond", id="nanosecond"),
pytest.param("m", "minute", id="minute"),
pytest.param("h", "hour", id="hour"),
pytest.param("D", "day", id="day"),
pytest.param("W", "week", id="week"),
pytest.param("M", "month", id="month"),
pytest.param("Y", "year", id="year"),
),
)
@helpers.requires_numpy
def test_init_quantity_np(self):
td = np.timedelta64(3, "s")
assert self.Q_(td) == 3 * self.ureg.second
q_hours = self.Q_(td, "hours")
assert q_hours == 3 * self.ureg.second
assert q_hours.units == self.ureg.hour
def test_init_quantity_np(self, timedelta_unit, pint_unit):
# test init with the timedelta unit
td = np.timedelta64(3, timedelta_unit)
result = self.Q_(td)
expected = self.Q_(3, pint_unit)
helpers.assert_quantity_almost_equal(result, expected)
# check units are same. Use Q_ since Unit(s) != Unit(second)
helpers.assert_quantity_almost_equal(
self.Q_(1, result.units), self.Q_(1, expected.units)
)

td = np.array([3], dtype="timedelta64")
assert self.Q_(td) == np.array([3]) * self.ureg.second
q_hours = self.Q_(td, "hours")
assert q_hours == np.array([3]) * self.ureg.second
assert q_hours.units == self.ureg.hour
# test init with unit specified
result = self.Q_(td, "hours")
expected = self.Q_(3, pint_unit).to("hours")
helpers.assert_quantity_almost_equal(result, expected)
helpers.assert_quantity_almost_equal(
self.Q_(1, result.units), self.Q_(1, expected.units)
)

# test array
td = np.array([3], dtype="timedelta64[{}]".format(timedelta_unit))
result = self.Q_(td)
expected = self.Q_([3], pint_unit)
helpers.assert_quantity_almost_equal(result, expected)
helpers.assert_quantity_almost_equal(
self.Q_(1, result.units), self.Q_(1, expected.units)
)

# test array with unit specified
result = self.Q_(td, "hours")
expected = self.Q_([3], pint_unit).to("hours")
helpers.assert_quantity_almost_equal(result, expected)
helpers.assert_quantity_almost_equal(
self.Q_(1, result.units), self.Q_(1, expected.units)
)


# TODO: do not subclass from QuantityTestCase
Expand Down

0 comments on commit 104b2cd

Please sign in to comment.