diff --git a/pymatgen/core/units.py b/pymatgen/core/units.py index 11364373f98..8c7b228c256 100644 --- a/pymatgen/core/units.py +++ b/pymatgen/core/units.py @@ -1,8 +1,8 @@ -"""This module implements a FloatWithUnit, which is a subclass of float. It -also defines supported units for some commonly used units for energy, length, -temperature, time and charge. FloatWithUnit also support conversion to one +"""This module implements FloatWithUnit, a subclass of float. It +also defines supported units for commonly used units for energy, length, +temperature, time and charge. FloatWithUnit also support conversion to another, and additions and subtractions perform automatic conversion if -units are detected. An ArrayWithUnit is also implemented, which is a subclass +units are detected. An ArrayWithUnit is also implemented, a subclass of numpy's ndarray with similar unit features. """ @@ -19,8 +19,10 @@ import scipy.constants as const if TYPE_CHECKING: + from collections.abc import Iterator from typing import Any + from numpy.typing import NDArray from typing_extensions import Self __author__ = "Shyue Ping Ong, Matteo Giantomassi" @@ -167,7 +169,7 @@ class Unit(collections.abc.Mapping): Only integer powers are supported for units. """ - def __init__(self, unit_def) -> None: + def __init__(self, unit_def: str | dict[str, int]) -> None: """ Args: unit_def: A definition for the unit. Either a mapping of unit to @@ -180,34 +182,34 @@ def __init__(self, unit_def) -> None: unit: dict[str, int] = defaultdict(int) for match in re.finditer(r"([A-Za-z]+)\s*\^*\s*([\-0-9]*)", unit_def): - val = match.group(2) - val = 1 if not val else int(val) - key = match.group(1) + val = match[2] + val = int(val) if val else 1 + key = match[1] unit[key] += val else: unit = {k: v for k, v in dict(unit_def).items() if v != 0} self._unit = _check_mappings(unit) - def __mul__(self, other): - new_units = defaultdict(int) + def __mul__(self, other: Self) -> Self: + new_units: defaultdict = defaultdict(int) for k, v in self.items(): new_units[k] += v for k, v in other.items(): new_units[k] += v - return Unit(new_units) + return type(self)(new_units) - def __truediv__(self, other): - new_units = defaultdict(int) + def __truediv__(self, other: Self) -> Self: + new_units: defaultdict = defaultdict(int) for k, v in self.items(): new_units[k] += v for k, v in other.items(): new_units[k] -= v - return Unit(new_units) + return type(self)(new_units) - def __pow__(self, i): - return Unit({k: v * i for k, v in self.items()}) + def __pow__(self, i: Self) -> Self: + return type(self)({k: v * i for k, v in self.items()}) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self._unit) def __getitem__(self, i) -> int: @@ -223,15 +225,15 @@ def __repr__(self) -> str: ) @property - def as_base_units(self): + def as_base_units(self) -> tuple[dict, float]: """Convert all units to base SI units, including derived units. Returns: tuple[dict, float]: (base_units_dict, scaling factor). base_units_dict will not contain any constants, which are gathered in the scaling factor. """ - b = defaultdict(int) - factor = 1 + base_units: defaultdict = defaultdict(int) + factor: float = 1 for k, v in self.items(): derived = False for dct in DERIVED_UNITS.values(): @@ -240,16 +242,16 @@ def as_base_units(self): if isinstance(k2, Number): factor *= k2 ** (v2 * v) else: - b[k2] += v2 * v + base_units[k2] += v2 * v derived = True break if not derived: si, f = _get_si_unit(k) - b[si] += v + base_units[si] += v factor *= f**v - return {k: v for k, v in b.items() if v != 0}, factor + return {k: v for k, v in base_units.items() if v != 0}, factor - def get_conversion_factor(self, new_unit): + def get_conversion_factor(self, new_unit: str) -> float: """Get a conversion factor between this unit and a new unit. Compound units are supported, but must have the same powers in each unit type. @@ -258,9 +260,10 @@ def get_conversion_factor(self, new_unit): new_unit: The new unit. """ old_base, old_factor = self.as_base_units - new_base, new_factor = Unit(new_unit).as_base_units + new_base, new_factor = type(self)(new_unit).as_base_units units_new = sorted(new_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) units_old = sorted(old_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) + factor = old_factor / new_factor for old, new in zip(units_old, units_new): if old[1] != new[1]: @@ -292,16 +295,22 @@ class FloatWithUnit(float): 32.932522246000005 eV """ - def __init__(self, val: float | Number, unit: str, unit_type: str | None = None) -> None: + def __init__( + self, + val: float | Number, + unit: str, + unit_type: str | None = None, + ) -> None: """Initialize a float with unit. Args: val (float): Value - unit (Unit): A unit. e.g. "C". + unit (str): A unit. e.g. "C". unit_type (str): A type of unit. e.g. "charge" """ if unit_type is not None and str(unit) not in ALL_UNITS[unit_type]: raise UnitError(f"{unit} is not a supported unit for {unit_type}") + self._unit = Unit(unit) self._unit_type = unit_type @@ -323,7 +332,7 @@ def __add__(self, other): val = other if other.unit != self._unit: val = other.to(self._unit) - return FloatWithUnit(float(self) + val, unit_type=self._unit_type, unit=self._unit) + return type(self)(float(self) + val, unit_type=self._unit_type, unit=self._unit) def __sub__(self, other): if not hasattr(other, "unit_type"): @@ -333,29 +342,29 @@ def __sub__(self, other): val = other if other.unit != self._unit: val = other.to(self._unit) - return FloatWithUnit(float(self) - val, unit_type=self._unit_type, unit=self._unit) + return type(self)(float(self) - val, unit_type=self._unit_type, unit=self._unit) def __mul__(self, other): - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit) + if not isinstance(other, type(self)): + return type(self)(float(self) * other, unit_type=self._unit_type, unit=self._unit) + return type(self)(float(self) * other, unit_type=None, unit=self._unit * other._unit) def __rmul__(self, other): - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit) + if not isinstance(other, type(self)): + return type(self)(float(self) * other, unit_type=self._unit_type, unit=self._unit) + return type(self)(float(self) * other, unit_type=None, unit=self._unit * other._unit) def __pow__(self, i): - return FloatWithUnit(float(self) ** i, unit_type=None, unit=self._unit**i) + return type(self)(float(self) ** i, unit_type=None, unit=self._unit**i) def __truediv__(self, other): val = super().__truediv__(other) - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(val, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(val, unit_type=None, unit=self._unit / other._unit) + if not isinstance(other, type(self)): + return type(self)(val, unit_type=self._unit_type, unit=self._unit) + return type(self)(val, unit_type=None, unit=self._unit / other._unit) def __neg__(self): - return FloatWithUnit(super().__neg__(), unit_type=self._unit_type, unit=self._unit) + return type(self)(super().__neg__(), unit_type=self._unit_type, unit=self._unit) def __getnewargs__(self): """Used by pickle to recreate object.""" @@ -408,25 +417,25 @@ def from_str(cls, string: str) -> Self: return cls(num, unit, unit_type=unit_type) return cls(num, unit, unit_type=None) - def to(self, new_unit): - """Conversion to a new_unit. Right now, only supports 1 to 1 mapping of - units of each type. + def to(self, new_unit: str) -> Self: + """Conversion to a new_unit. Right now, only supports + 1 to 1 mapping of units of each type. Args: new_unit: New unit type. Returns: - A FloatWithUnit object in the new units. + A FloatWithUnit in the new units. Example usage: - >>> e = Energy(1.1, "eV") - >>> e = Energy(1.1, "Ha") - >>> e.to("eV") + >>> energy = Energy(1.1, "eV") + >>> energy = Energy(1.1, "Ha") + >>> energy.to("eV") 29.932522246 eV """ - return FloatWithUnit( + return type(self)( self * self.unit.get_conversion_factor(new_unit), - unit_type=self._unit_type, + unit_type=self.unit_type, unit=new_unit, ) @@ -435,14 +444,17 @@ def as_base_units(self): """This FloatWithUnit in base SI units, including derived units. Returns: - A FloatWithUnit object in base SI units + A FloatWithUnit in base SI units """ return self.to(self.unit.as_base_units[0]) @property - def supported_units(self): + def supported_units(self) -> tuple: """Supported units for specific unit type.""" - return tuple(ALL_UNITS[self._unit_type]) + if self.unit_type is None: + raise RuntimeError("Cannot get supported unit for None.") + + return tuple(ALL_UNITS[self.unit_type]) class ArrayWithUnit(np.ndarray): @@ -450,7 +462,7 @@ class ArrayWithUnit(np.ndarray): use the pre-defined unit type subclasses such as EnergyArray, LengthArray, etc. instead of using ArrayWithFloatWithUnit directly. - Supports conversion, addition and subtraction of the same unit type. e.g. + Support conversion, addition and subtraction of the same unit type. e.g. 1 m + 20 cm will be automatically converted to 1.2 m (units follow the leftmost quantity). @@ -463,35 +475,30 @@ class ArrayWithUnit(np.ndarray): array([ 28.21138386, 56.42276772]) eV """ - def __new__(cls, input_array, unit, unit_type=None) -> Self: + def __new__( + cls, + input_array: NDArray, + unit: str, + unit_type: str | None = None, + ) -> Self: """Override __new__.""" # Input array is an already formed ndarray instance # We first cast to be our class type obj = np.asarray(input_array).view(cls) - # add the new attributes to the created instance + # Add the new attributes to the created instance obj._unit = Unit(unit) obj._unit_type = unit_type return obj - def __array_finalize__(self, obj): - """See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for - comments. + def __array_finalize__(self, obj) -> None: + """See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html + for comments. """ if obj is None: return self._unit = getattr(obj, "_unit", None) self._unit_type = getattr(obj, "_unit_type", None) - @property - def unit_type(self) -> str: - """The type of unit. Energy, Charge, etc.""" - return self._unit_type - - @property - def unit(self) -> str: - """The unit, e.g. "eV".""" - return self._unit - def __reduce__(self): reduce = list(super().__reduce__()) reduce[2] = {"np_state": reduce[2], "_unit": self._unit} @@ -565,7 +572,17 @@ def __truediv__(self, other): def __neg__(self): return type(self)(-np.array(self), unit_type=self.unit_type, unit=self.unit) - def to(self, new_unit): + @property + def unit_type(self) -> str | None: + """The type of unit. Energy, Charge, etc.""" + return self._unit_type + + @property + def unit(self) -> Unit: + """The unit, e.g. "eV".""" + return self._unit + + def to(self, new_unit: str) -> Self: """Conversion to a new_unit. Args: @@ -597,12 +614,12 @@ def as_base_units(self): # TODO abstract base class property? @property - def supported_units(self): + def supported_units(self) -> dict: """Supported units for specific unit type.""" return ALL_UNITS[self.unit_type] # TODO abstract base class method? - def conversions(self): + def conversions(self) -> str: """Get a string showing the available conversions. Useful tool in interactive mode. """