Skip to content

Commit

Permalink
add some for core.units
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed May 17, 2024
1 parent a6c2806 commit b53f00f
Showing 1 changed file with 88 additions and 71 deletions.
159 changes: 88 additions & 71 deletions pymatgen/core/units.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""This module implements a FloatWithUnit, which is a subclass of float. It
also defines supported units for some commonly used units for energy, length,
temperature, time and charge. FloatWithUnit also support conversion to one
"""This module implements FloatWithUnit, a subclass of float. It
also defines supported units for commonly used units for energy, length,
temperature, time and charge. FloatWithUnit also support conversion to
another, and additions and subtractions perform automatic conversion if
units are detected. An ArrayWithUnit is also implemented, which is a subclass
units are detected. An ArrayWithUnit is also implemented, a subclass
of numpy's ndarray with similar unit features.
"""

Expand All @@ -19,8 +19,10 @@
import scipy.constants as const

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any

from numpy.typing import NDArray
from typing_extensions import Self

__author__ = "Shyue Ping Ong, Matteo Giantomassi"
Expand Down Expand Up @@ -167,7 +169,7 @@ class Unit(collections.abc.Mapping):
Only integer powers are supported for units.
"""

def __init__(self, unit_def) -> None:
def __init__(self, unit_def: str | dict[str, int]) -> None:
"""
Args:
unit_def: A definition for the unit. Either a mapping of unit to
Expand All @@ -180,34 +182,34 @@ def __init__(self, unit_def) -> None:
unit: dict[str, int] = defaultdict(int)

for match in re.finditer(r"([A-Za-z]+)\s*\^*\s*([\-0-9]*)", unit_def):
val = match.group(2)
val = 1 if not val else int(val)
key = match.group(1)
val = match[2]
val = int(val) if val else 1
key = match[1]
unit[key] += val
else:
unit = {k: v for k, v in dict(unit_def).items() if v != 0}
self._unit = _check_mappings(unit)

def __mul__(self, other):
new_units = defaultdict(int)
def __mul__(self, other: Self) -> Self:
new_units: defaultdict = defaultdict(int)
for k, v in self.items():
new_units[k] += v
for k, v in other.items():
new_units[k] += v
return Unit(new_units)
return type(self)(new_units)

def __truediv__(self, other):
new_units = defaultdict(int)
def __truediv__(self, other: Self) -> Self:
new_units: defaultdict = defaultdict(int)
for k, v in self.items():
new_units[k] += v
for k, v in other.items():
new_units[k] -= v
return Unit(new_units)
return type(self)(new_units)

def __pow__(self, i):
return Unit({k: v * i for k, v in self.items()})
def __pow__(self, i: Self) -> Self:
return type(self)({k: v * i for k, v in self.items()})

def __iter__(self):
def __iter__(self) -> Iterator:
return iter(self._unit)

def __getitem__(self, i) -> int:
Expand All @@ -223,15 +225,15 @@ def __repr__(self) -> str:
)

@property
def as_base_units(self):
def as_base_units(self) -> tuple[dict, float]:
"""Convert all units to base SI units, including derived units.
Returns:
tuple[dict, float]: (base_units_dict, scaling factor). base_units_dict will not
contain any constants, which are gathered in the scaling factor.
"""
b = defaultdict(int)
factor = 1
base_units: defaultdict = defaultdict(int)
factor: float = 1
for k, v in self.items():
derived = False
for dct in DERIVED_UNITS.values():
Expand All @@ -240,16 +242,16 @@ def as_base_units(self):
if isinstance(k2, Number):
factor *= k2 ** (v2 * v)
else:
b[k2] += v2 * v
base_units[k2] += v2 * v
derived = True
break
if not derived:
si, f = _get_si_unit(k)
b[si] += v
base_units[si] += v
factor *= f**v
return {k: v for k, v in b.items() if v != 0}, factor
return {k: v for k, v in base_units.items() if v != 0}, factor

def get_conversion_factor(self, new_unit):
def get_conversion_factor(self, new_unit: str) -> float:
"""Get a conversion factor between this unit and a new unit.
Compound units are supported, but must have the same powers in each
unit type.
Expand All @@ -258,9 +260,10 @@ def get_conversion_factor(self, new_unit):
new_unit: The new unit.
"""
old_base, old_factor = self.as_base_units
new_base, new_factor = Unit(new_unit).as_base_units
new_base, new_factor = type(self)(new_unit).as_base_units
units_new = sorted(new_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])
units_old = sorted(old_base.items(), key=lambda d: _UNAME2UTYPE[d[0]])

factor = old_factor / new_factor
for old, new in zip(units_old, units_new):
if old[1] != new[1]:
Expand Down Expand Up @@ -292,16 +295,22 @@ class FloatWithUnit(float):
32.932522246000005 eV
"""

def __init__(self, val: float | Number, unit: str, unit_type: str | None = None) -> None:
def __init__(
self,
val: float | Number,
unit: str,
unit_type: str | None = None,
) -> None:
"""Initialize a float with unit.
Args:
val (float): Value
unit (Unit): A unit. e.g. "C".
unit (str): A unit. e.g. "C".
unit_type (str): A type of unit. e.g. "charge"
"""
if unit_type is not None and str(unit) not in ALL_UNITS[unit_type]:
raise UnitError(f"{unit} is not a supported unit for {unit_type}")

self._unit = Unit(unit)
self._unit_type = unit_type

Expand All @@ -323,7 +332,7 @@ def __add__(self, other):
val = other
if other.unit != self._unit:
val = other.to(self._unit)
return FloatWithUnit(float(self) + val, unit_type=self._unit_type, unit=self._unit)
return type(self)(float(self) + val, unit_type=self._unit_type, unit=self._unit)

def __sub__(self, other):
if not hasattr(other, "unit_type"):
Expand All @@ -333,29 +342,29 @@ def __sub__(self, other):
val = other
if other.unit != self._unit:
val = other.to(self._unit)
return FloatWithUnit(float(self) - val, unit_type=self._unit_type, unit=self._unit)
return type(self)(float(self) - val, unit_type=self._unit_type, unit=self._unit)

def __mul__(self, other):
if not isinstance(other, FloatWithUnit):
return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit)
return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit)
if not isinstance(other, type(self)):
return type(self)(float(self) * other, unit_type=self._unit_type, unit=self._unit)
return type(self)(float(self) * other, unit_type=None, unit=self._unit * other._unit)

def __rmul__(self, other):
if not isinstance(other, FloatWithUnit):
return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit)
return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit)
if not isinstance(other, type(self)):
return type(self)(float(self) * other, unit_type=self._unit_type, unit=self._unit)
return type(self)(float(self) * other, unit_type=None, unit=self._unit * other._unit)

def __pow__(self, i):
return FloatWithUnit(float(self) ** i, unit_type=None, unit=self._unit**i)
return type(self)(float(self) ** i, unit_type=None, unit=self._unit**i)

def __truediv__(self, other):
val = super().__truediv__(other)
if not isinstance(other, FloatWithUnit):
return FloatWithUnit(val, unit_type=self._unit_type, unit=self._unit)
return FloatWithUnit(val, unit_type=None, unit=self._unit / other._unit)
if not isinstance(other, type(self)):
return type(self)(val, unit_type=self._unit_type, unit=self._unit)
return type(self)(val, unit_type=None, unit=self._unit / other._unit)

def __neg__(self):
return FloatWithUnit(super().__neg__(), unit_type=self._unit_type, unit=self._unit)
return type(self)(super().__neg__(), unit_type=self._unit_type, unit=self._unit)

def __getnewargs__(self):
"""Used by pickle to recreate object."""
Expand Down Expand Up @@ -408,25 +417,25 @@ def from_str(cls, string: str) -> Self:
return cls(num, unit, unit_type=unit_type)
return cls(num, unit, unit_type=None)

def to(self, new_unit):
"""Conversion to a new_unit. Right now, only supports 1 to 1 mapping of
units of each type.
def to(self, new_unit: str) -> Self:
"""Conversion to a new_unit. Right now, only supports
1 to 1 mapping of units of each type.
Args:
new_unit: New unit type.
Returns:
A FloatWithUnit object in the new units.
A FloatWithUnit in the new units.
Example usage:
>>> e = Energy(1.1, "eV")
>>> e = Energy(1.1, "Ha")
>>> e.to("eV")
>>> energy = Energy(1.1, "eV")
>>> energy = Energy(1.1, "Ha")
>>> energy.to("eV")
29.932522246 eV
"""
return FloatWithUnit(
return type(self)(
self * self.unit.get_conversion_factor(new_unit),
unit_type=self._unit_type,
unit_type=self.unit_type,
unit=new_unit,
)

Expand All @@ -435,22 +444,25 @@ def as_base_units(self):
"""This FloatWithUnit in base SI units, including derived units.
Returns:
A FloatWithUnit object in base SI units
A FloatWithUnit in base SI units
"""
return self.to(self.unit.as_base_units[0])

@property
def supported_units(self):
def supported_units(self) -> tuple:
"""Supported units for specific unit type."""
return tuple(ALL_UNITS[self._unit_type])
if self.unit_type is None:
raise RuntimeError("Cannot get supported unit for None.")

return tuple(ALL_UNITS[self.unit_type])


class ArrayWithUnit(np.ndarray):
"""Subclasses numpy.ndarray to attach a unit type. Typically, you should
use the pre-defined unit type subclasses such as EnergyArray,
LengthArray, etc. instead of using ArrayWithFloatWithUnit directly.
Supports conversion, addition and subtraction of the same unit type. e.g.
Support conversion, addition and subtraction of the same unit type. e.g.
1 m + 20 cm will be automatically converted to 1.2 m (units follow the
leftmost quantity).
Expand All @@ -463,35 +475,30 @@ class ArrayWithUnit(np.ndarray):
array([ 28.21138386, 56.42276772]) eV
"""

def __new__(cls, input_array, unit, unit_type=None) -> Self:
def __new__(
cls,
input_array: NDArray,
unit: str,
unit_type: str | None = None,
) -> Self:
"""Override __new__."""
# Input array is an already formed ndarray instance
# We first cast to be our class type
obj = np.asarray(input_array).view(cls)
# add the new attributes to the created instance
# Add the new attributes to the created instance
obj._unit = Unit(unit)
obj._unit_type = unit_type
return obj

def __array_finalize__(self, obj):
"""See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for
comments.
def __array_finalize__(self, obj) -> None:
"""See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html
for comments.
"""
if obj is None:
return
self._unit = getattr(obj, "_unit", None)
self._unit_type = getattr(obj, "_unit_type", None)

@property
def unit_type(self) -> str:
"""The type of unit. Energy, Charge, etc."""
return self._unit_type

@property
def unit(self) -> str:
"""The unit, e.g. "eV"."""
return self._unit

def __reduce__(self):
reduce = list(super().__reduce__())
reduce[2] = {"np_state": reduce[2], "_unit": self._unit}
Expand Down Expand Up @@ -565,7 +572,17 @@ def __truediv__(self, other):
def __neg__(self):
return type(self)(-np.array(self), unit_type=self.unit_type, unit=self.unit)

def to(self, new_unit):
@property
def unit_type(self) -> str | None:
"""The type of unit. Energy, Charge, etc."""
return self._unit_type

@property
def unit(self) -> Unit:
"""The unit, e.g. "eV"."""
return self._unit

def to(self, new_unit: str) -> Self:
"""Conversion to a new_unit.
Args:
Expand Down Expand Up @@ -597,12 +614,12 @@ def as_base_units(self):

# TODO abstract base class property?
@property
def supported_units(self):
def supported_units(self) -> dict:
"""Supported units for specific unit type."""
return ALL_UNITS[self.unit_type]

# TODO abstract base class method?
def conversions(self):
def conversions(self) -> str:
"""Get a string showing the available conversions.
Useful tool in interactive mode.
"""
Expand Down

0 comments on commit b53f00f

Please sign in to comment.