diff --git a/pymatgen/core/trajectory.py b/pymatgen/core/trajectory.py index 925020f4297..9d2294e6ae4 100644 --- a/pymatgen/core/trajectory.py +++ b/pymatgen/core/trajectory.py @@ -8,7 +8,7 @@ import warnings from fnmatch import fnmatch from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, cast import numpy as np from monty.io import zopen @@ -20,16 +20,19 @@ if TYPE_CHECKING: from collections.abc import Iterator + from typing import Any from typing_extensions import Self - from pymatgen.util.typing import Matrix3D, SitePropsType, Vector3D + from pymatgen.util.typing import Matrix3D, PathLike, SitePropsType, Vector3D __author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith" __version__ = "0.1" __date__ = "Jun 29, 2022" +ValidIndex = Union[int, slice, list[int], np.ndarray] + class Trajectory(MSONable): """Trajectory of a geometry optimization or molecular dynamics simulation. @@ -123,7 +126,7 @@ def __init__( if isinstance(lattice, Lattice): lattice = lattice.matrix elif isinstance(lattice, list) and isinstance(lattice[0], Lattice): - lattice = [x.matrix for x in lattice] # type: ignore + lattice = [cast(Lattice, x).matrix for x in lattice] lattice = np.asarray(lattice) if not constant_lattice and lattice.shape == (3, 3): @@ -158,6 +161,105 @@ def __init__( self._check_frame_props(frame_properties) self.frame_properties = frame_properties + def __iter__(self) -> Iterator[Structure | Molecule]: + """Iterator of the trajectory, yielding a pymatgen Structure or Molecule for each frame.""" + for idx in range(len(self)): + yield self[idx] + + def __len__(self) -> int: + """Number of frames in the trajectory.""" + return len(self.coords) + + def __getitem__(self, frames: ValidIndex) -> Molecule | Structure | Self: + """Get a subset of the trajectory. + + The output depends on the type of the input `frames`. If an int is given, return + a pymatgen Molecule or Structure at the specified frame. If a list or a slice, return a new + trajectory with a subset of frames. + + Args: + frames: Indices of the trajectory to return. + + Returns: + Subset of trajectory + """ + # Convert to position mode if not already + self.to_positions() + + # For integer input, return the structure at that frame + if isinstance(frames, int): + if frames >= len(self): + raise IndexError(f"index={frames} out of range, trajectory only has {len(self)} frames") + + if self.lattice is None: + charge = 0 if self.charge is None else int(self.charge) + spin = None if self.spin_multiplicity is None else int(self.spin_multiplicity) + + return Molecule( + self.species, + self.coords[frames], + charge=charge, + spin_multiplicity=spin, + site_properties=self._get_site_props(frames), # type: ignore[arg-type] + ) + + lattice = self.lattice if self.constant_lattice else self.lattice[frames] + + return Structure( + Lattice(lattice), + self.species, + self.coords[frames], + site_properties=self._get_site_props(frames), # type: ignore[arg-type] + to_unit_cell=True, + ) + + # For slice input, return a trajectory + if isinstance(frames, (slice, list, np.ndarray)): + if isinstance(frames, slice): + start, stop, step = frames.indices(len(self)) + selected = list(range(start, stop, step)) + else: + # Get rid of frames that exceed trajectory length + selected = [idx for idx in frames if idx < len(self)] + + if len(selected) < len(frames): + bad_frames = [idx for idx in frames if idx > len(self)] + raise IndexError(f"index={bad_frames} out of range, trajectory only has {len(self)} frames") + + coords = self.coords[selected] + frame_properties = ( + None if self.frame_properties is None else [self.frame_properties[idx] for idx in selected] + ) + + if self.lattice is None: + return type(self)( + species=self.species, + coords=coords, + charge=self.charge, + spin_multiplicity=self.spin_multiplicity, + site_properties=self._get_site_props(selected), + frame_properties=frame_properties, + time_step=self.time_step, + coords_are_displacement=False, + base_positions=self.base_positions, + ) + + lattice = self.lattice if self.constant_lattice else self.lattice[selected] + + return type(self)( + species=self.species, + coords=coords, + lattice=lattice, + site_properties=self._get_site_props(selected), + frame_properties=frame_properties, + constant_lattice=self.constant_lattice, + time_step=self.time_step, + coords_are_displacement=False, + base_positions=self.base_positions, + ) + + raise TypeError(f"bad index={frames!r}, expected one of {str(ValidIndex).split('Union')[1]}") + def get_structure(self, idx: int) -> Structure: """Get structure at specified index. @@ -290,115 +392,9 @@ def extend(self, trajectory: Trajectory) -> None: # len(self) is used there. self.coords = np.concatenate((self.coords, trajectory.coords)) - def __iter__(self) -> Iterator[Structure | Molecule]: - """Iterator of the trajectory, yielding a pymatgen Structure or Molecule for each frame.""" - for idx in range(len(self)): - yield self[idx] - - def __len__(self) -> int: - """Number of frames in the trajectory.""" - return len(self.coords) - - def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | Trajectory: - """Get a subset of the trajectory. - - The output depends on the type of the input `frames`. If an int is given, return - a pymatgen Molecule or Structure at the specified frame. If a list or a slice, return a new - trajectory with a subset of frames. - - Args: - frames: Indices of the trajectory to return. - - Returns: - Subset of trajectory - """ - # Convert to position mode if not already - self.to_positions() - - # For integer input, return the structure at that frame - if isinstance(frames, int): - if frames >= len(self): - raise IndexError(f"Frame index {frames} out of range.") - - if self.lattice is None: - charge = 0 - if self.charge is not None: - charge = int(self.charge) - - spin = None - if self.spin_multiplicity is not None: - spin = int(self.spin_multiplicity) - - return Molecule( - self.species, - self.coords[frames], - charge=charge, - spin_multiplicity=spin, - site_properties=self._get_site_props(frames), # type: ignore - ) - - lattice = self.lattice if self.constant_lattice else self.lattice[frames] # type: ignore - - return Structure( - Lattice(lattice), - self.species, - self.coords[frames], - site_properties=self._get_site_props(frames), # type: ignore - to_unit_cell=True, - ) - - # For slice input, return a trajectory - if isinstance(frames, (slice, list, np.ndarray)): - if isinstance(frames, slice): - start, stop, step = frames.indices(len(self)) - selected = list(range(start, stop, step)) - else: - # Get rid of frames that exceed trajectory length - selected = [i for i in frames if i < len(self)] - - if len(selected) < len(frames): - bad_frames = [i for i in frames if i > len(self)] - raise IndexError(f"Frame index {bad_frames} out of range.") - - coords = self.coords[selected] - if self.frame_properties is not None: - frame_properties = [self.frame_properties[i] for i in selected] - else: - frame_properties = None - - if self.lattice is None: - return Trajectory( - species=self.species, - coords=coords, - charge=self.charge, - spin_multiplicity=self.spin_multiplicity, - site_properties=self._get_site_props(selected), - frame_properties=frame_properties, - time_step=self.time_step, - coords_are_displacement=False, - base_positions=self.base_positions, - ) - - lattice = self.lattice if self.constant_lattice else self.lattice[selected] # type: ignore - - return Trajectory( - species=self.species, - coords=coords, - lattice=lattice, - site_properties=self._get_site_props(selected), - frame_properties=frame_properties, - constant_lattice=self.constant_lattice, - time_step=self.time_step, - coords_are_displacement=False, - base_positions=self.base_positions, - ) - - supported = [int, slice, list or np.ndarray] - raise ValueError(f"Expect the type of frames be one of {supported}; {type(frames)}.") - def write_Xdatcar( self, - filename: str | Path = "XDATCAR", + filename: PathLike = "XDATCAR", system: str | None = None, significant_figures: int = 6, ) -> None: @@ -408,7 +404,7 @@ def write_Xdatcar( Xdatcar_from_structs.get_str method and are passed through directly. Args: - filename: Name of file to write. It's prudent to end the filename with + filename: File to write. It's prudent to end the filename with 'XDATCAR', as most visualization and analysis software require this for autodetection. system: Description of system (e.g. 2D MoS2). @@ -435,7 +431,7 @@ def write_Xdatcar( if idx == 0 or not self.constant_lattice: lines.extend([system, "1.0"]) - _lattice = self.lattice if self.constant_lattice else self.lattice[idx] # type: ignore + _lattice = self.lattice if self.constant_lattice else self.lattice[idx] for latt_vec in _lattice: lines.append(f'{" ".join(map(str, latt_vec))}') @@ -448,10 +444,10 @@ def write_Xdatcar( line = f'{" ".join(format_str.format(c) for c in coord)} {specie}' lines.append(line) - xdatcar_string = "\n".join(lines) + "\n" + xdatcar_str = "\n".join(lines) + "\n" with zopen(filename, mode="wt") as file: - file.write(xdatcar_string) + file.write(xdatcar_str) def as_dict(self) -> dict: """Return the trajectory as a MSONable dict.""" @@ -493,15 +489,15 @@ def from_structures(cls, structures: list[Structure], constant_lattice: bool = T else: lattice = np.array([structure.lattice.matrix for structure in structures]) - species = structures[0].species + species: list[Element | Species] = structures[0].species coords = [structure.frac_coords for structure in structures] site_properties = [structure.site_properties for structure in structures] return cls( - species=species, # type: ignore + species=species, # type: ignore[arg-type] coords=coords, lattice=lattice, - site_properties=site_properties, # type: ignore + site_properties=site_properties, constant_lattice=constant_lattice, **kwargs, ) @@ -524,11 +520,11 @@ def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Self: site_properties = [mol.site_properties for mol in molecules] return cls( - species=species, # type: ignore + species=species, # type: ignore[arg-type] coords=coords, charge=int(molecules[0].charge), spin_multiplicity=int(molecules[0].spin_multiplicity), - site_properties=site_properties, # type: ignore + site_properties=site_properties, **kwargs, ) @@ -561,7 +557,7 @@ def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs from ase.io.trajectory import Trajectory as AseTrajectory ase_traj = AseTrajectory(filename) - # periodic boundary conditions should be the same for all frames so just check the first + # Periodic boundary conditions should be the same for all frames so just check the first pbc = ase_traj[0].pbc if any(pbc): structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in ase_traj] @@ -582,7 +578,12 @@ def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs return cls.from_structures(structures, constant_lattice=constant_lattice, **kwargs) @staticmethod - def _combine_lattice(lat1: np.ndarray, lat2: np.ndarray, len1: int, len2: int) -> tuple[np.ndarray, bool]: + def _combine_lattice( + lat1: np.ndarray, + lat2: np.ndarray, + len1: int, + len2: int, + ) -> tuple[np.ndarray, bool]: """Helper function to combine trajectory lattice.""" if lat1.ndim == lat2.ndim == 2: constant_lat = True @@ -599,51 +600,57 @@ def _combine_lattice(lat1: np.ndarray, lat2: np.ndarray, len1: int, len2: int) - @staticmethod def _combine_site_props( - prop1: SitePropsType | None, prop2: SitePropsType | None, len1: int, len2: int + prop1: SitePropsType | None, + prop2: SitePropsType | None, + len1: int, + len2: int, ) -> SitePropsType | None: """Combine site properties. Either one of prop1 or prop2 can be None, dict, or a list of dict. All possibilities of combining them are considered. """ - # special cases - + # Special cases if prop1 is prop2 is None: return None if isinstance(prop1, dict) and prop1 == prop2: return prop1 - # general case - + # General case assert prop1 is None or isinstance(prop1, (list, dict)) assert prop2 is None or isinstance(prop2, (list, dict)) - p1_candidates = { + p1_candidates: dict[str, Any] = { "NoneType": [None] * len1, "dict": [prop1] * len1, "list": prop1, } - p2_candidates = { + p2_candidates: dict[str, Any] = { "NoneType": [None] * len2, "dict": [prop2] * len2, "list": prop2, } - p1_selected: list = p1_candidates[type(prop1).__name__] # type: ignore - p2_selected: list = p2_candidates[type(prop2).__name__] # type: ignore + p1_selected: list = p1_candidates[type(prop1).__name__] + p2_selected: list = p2_candidates[type(prop2).__name__] return p1_selected + p2_selected @staticmethod - def _combine_frame_props(prop1: list[dict] | None, prop2: list[dict] | None, len1: int, len2: int) -> list | None: + def _combine_frame_props( + prop1: list[dict] | None, + prop2: list[dict] | None, + len1: int, + len2: int, + ) -> list | None: """Combine frame properties.""" if prop1 is prop2 is None: return None if prop1 is None: - return [None] * len1 + list(prop2) # type: ignore + return [None] * len1 + list(cast(list[dict], prop2)) if prop2 is None: - return list(prop1) + [None] * len2 # type: ignore - return list(prop1) + list(prop2) # type:ignore + return list(prop1) + [None] * len2 + return list(prop1) + list(prop2) def _check_site_props(self, site_props: SitePropsType | None) -> None: """Check data shape of site properties. @@ -660,10 +667,10 @@ def _check_site_props(self, site_props: SitePropsType | None) -> None: if isinstance(site_props, dict): site_props = [site_props] - else: - assert len(site_props) == len( - self - ), f"Size of the site properties {len(site_props)} does not equal to the number of frames {len(self)}." + elif len(site_props) != len(self): + raise AssertionError( + f"Size of the site properties {len(site_props)} does not equal to the number of frames {len(self)}" + ) n_sites = len(self.coords[0]) for dct in site_props: @@ -678,11 +685,12 @@ def _check_frame_props(self, frame_props: list[dict] | None) -> None: if frame_props is None: return - assert len(frame_props) == len( - self - ), f"Size of the frame properties {len(frame_props)} does not equal to the number of frames {len(self)}." + if len(frame_props) != len(self): + raise AssertionError( + f"Size of the frame properties {len(frame_props)} does not equal to the number of frames {len(self)}" + ) - def _get_site_props(self, frames: int | list[int]) -> SitePropsType | None: + def _get_site_props(self, frames: ValidIndex) -> SitePropsType | None: """Slice site properties.""" if self.site_properties is None: return None @@ -692,6 +700,6 @@ def _get_site_props(self, frames: int | list[int]) -> SitePropsType | None: if isinstance(frames, int): return self.site_properties[frames] if isinstance(frames, list): - return [self.site_properties[i] for i in frames] + return [self.site_properties[idx] for idx in frames] raise ValueError("Unexpected frames type.") raise ValueError("Unexpected site_properties type.") diff --git a/pymatgen/core/units.py b/pymatgen/core/units.py index 11364373f98..9abff8cffe4 100644 --- a/pymatgen/core/units.py +++ b/pymatgen/core/units.py @@ -1,9 +1,11 @@ -"""This module implements a FloatWithUnit, which is a subclass of float. It -also defines supported units for some commonly used units for energy, length, -temperature, time and charge. FloatWithUnit also support conversion to one -another, and additions and subtractions perform automatic conversion if -units are detected. An ArrayWithUnit is also implemented, which is a subclass -of numpy's ndarray with similar unit features. +"""This module defines commonly used units +for energy, length, temperature, time and charge. + +Also defines the following classes: +- FloatWithUnit, a subclass of float, which supports + conversion to another, and additions and subtractions + perform automatic conversion if units are detected. +- ArrayWithUnit, a subclass of numpy's ndarray with similar unit features. """ from __future__ import annotations @@ -13,14 +15,16 @@ from collections import defaultdict from functools import partial from numbers import Number -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np import scipy.constants as const if TYPE_CHECKING: + from collections.abc import Iterator from typing import Any + from numpy.typing import NDArray from typing_extensions import Self __author__ = "Shyue Ping Ong, Matteo Giantomassi" @@ -142,32 +146,16 @@ _UNAME2UTYPE = {uname: utype for utype, dct in ALL_UNITS.items() for uname in dct} -def _get_si_unit(unit): - unit_type = _UNAME2UTYPE[unit] - si_unit = filter(lambda k: BASE_UNITS[unit_type][k] == 1, BASE_UNITS[unit_type]) - return next(iter(si_unit)), BASE_UNITS[unit_type][unit] - - class UnitError(BaseException): """Exception class for unit errors.""" -def _check_mappings(u): - for v in DERIVED_UNITS.values(): - for k2, v2 in v.items(): - if all(v2.get(ku, 0) == vu for ku, vu in u.items()) and all( - u.get(kv2, 0) == vv2 for kv2, vv2 in v2.items() - ): - return {k2: 1} - return u - - class Unit(collections.abc.Mapping): - """Represents a unit, e.g. "m" for meters, etc. Supports compound units. - Only integer powers are supported for units. + """Represent a unit, e.g. "m" for meters, etc. Supports compound units. + Only integer powers are supported. """ - def __init__(self, unit_def) -> None: + def __init__(self, unit_def: str | dict[str, int]) -> None: """ Args: unit_def: A definition for the unit. Either a mapping of unit to @@ -176,38 +164,49 @@ def __init__(self, unit_def) -> None: format uses "^" as the power operator and all units must be space-separated. """ + + def check_mappings(u): + for v in DERIVED_UNITS.values(): + for k2, v2 in v.items(): + if all(v2.get(ku, 0) == vu for ku, vu in u.items()) and all( + u.get(kv2, 0) == vv2 for kv2, vv2 in v2.items() + ): + return {k2: 1} + return u + if isinstance(unit_def, str): unit: dict[str, int] = defaultdict(int) for match in re.finditer(r"([A-Za-z]+)\s*\^*\s*([\-0-9]*)", unit_def): - val = match.group(2) - val = 1 if not val else int(val) - key = match.group(1) + val = match[2] + val = int(val) if val else 1 + key = match[1] unit[key] += val else: unit = {k: v for k, v in dict(unit_def).items() if v != 0} - self._unit = _check_mappings(unit) - def __mul__(self, other): - new_units = defaultdict(int) + self._unit = check_mappings(unit) + + def __mul__(self, other: Self) -> Self: + new_units: defaultdict = defaultdict(int) for k, v in self.items(): new_units[k] += v for k, v in other.items(): new_units[k] += v - return Unit(new_units) + return type(self)(new_units) - def __truediv__(self, other): - new_units = defaultdict(int) + def __truediv__(self, other: Self) -> Self: + new_units: defaultdict = defaultdict(int) for k, v in self.items(): new_units[k] += v for k, v in other.items(): new_units[k] -= v - return Unit(new_units) + return type(self)(new_units) - def __pow__(self, i): - return Unit({k: v * i for k, v in self.items()}) + def __pow__(self, i: Self) -> Self: + return type(self)({k: v * i for k, v in self.items()}) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self._unit) def __getitem__(self, i) -> int: @@ -223,15 +222,21 @@ def __repr__(self) -> str: ) @property - def as_base_units(self): + def as_base_units(self) -> tuple[dict, float]: """Convert all units to base SI units, including derived units. Returns: tuple[dict, float]: (base_units_dict, scaling factor). base_units_dict will not contain any constants, which are gathered in the scaling factor. """ - b = defaultdict(int) - factor = 1 + + def get_si_unit(unit): + unit_type = _UNAME2UTYPE[unit] + si_unit = filter(lambda k: BASE_UNITS[unit_type][k] == 1, BASE_UNITS[unit_type]) + return next(iter(si_unit)), BASE_UNITS[unit_type][unit] + + base_units: defaultdict = defaultdict(int) + factor: float = 1 for k, v in self.items(): derived = False for dct in DERIVED_UNITS.values(): @@ -240,28 +245,32 @@ def as_base_units(self): if isinstance(k2, Number): factor *= k2 ** (v2 * v) else: - b[k2] += v2 * v + base_units[k2] += v2 * v derived = True break if not derived: - si, f = _get_si_unit(k) - b[si] += v + si, f = get_si_unit(k) + base_units[si] += v factor *= f**v - return {k: v for k, v in b.items() if v != 0}, factor + return {k: v for k, v in base_units.items() if v != 0}, factor - def get_conversion_factor(self, new_unit): - """Get a conversion factor between this unit and a new unit. + def get_conversion_factor(self, new_unit: str | Unit) -> float: + """Get the conversion factor between this unit and a new unit. Compound units are supported, but must have the same powers in each unit type. Args: - new_unit: The new unit. + new_unit (str | Unit): The new unit. """ + _new_unit: str = repr(new_unit) if isinstance(new_unit, Unit) else new_unit + old_base, old_factor = self.as_base_units - new_base, new_factor = Unit(new_unit).as_base_units + new_base, new_factor = type(self)(_new_unit).as_base_units + units_new = sorted(new_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) units_old = sorted(old_base.items(), key=lambda d: _UNAME2UTYPE[d[0]]) - factor = old_factor / new_factor + factor: float = old_factor / new_factor + for old, new in zip(units_old, units_new): if old[1] != new[1]: raise UnitError(f"Units {old} and {new} are not compatible!") @@ -275,40 +284,51 @@ class FloatWithUnit(float): pre-defined unit type subclasses such as Energy, Length, etc. instead of using FloatWithUnit directly. - Supports conversion, addition and subtraction of the same unit type. e.g. + Support conversion, addition and subtraction of the same unit type. e.g. 1 m + 20 cm will be automatically converted to 1.2 m (units follow the leftmost quantity). Note that FloatWithUnit does not override the eq method for float, i.e., units are not checked when testing for equality. The reason is to allow this class to be used transparently wherever floats are expected. - >>> e = Energy(1.1, "Ha") - >>> a = Energy(1.1, "Ha") - >>> b = Energy(3, "eV") - >>> c = a + b - >>> print(c) - 1.2102479761938871 Ha - >>> c.to("eV") - 32.932522246000005 eV + Example usage: + >>> energy_a = Energy(1.1, "Ha") + >>> energy_b = Energy(3, "eV") + >>> energy_c = energy_a + energy_b + >>> print(energy_c) + 1.2102479761938871 Ha + >>> energy_c.to("eV") + 32.932522246000005 eV """ - def __init__(self, val: float | Number, unit: str, unit_type: str | None = None) -> None: + def __init__( + self, + val: float | Number, + unit: str | Unit, + unit_type: str | None = None, + ) -> None: """Initialize a float with unit. Args: val (float): Value - unit (Unit): A unit. e.g. "C". + unit (str | Unit): A unit. e.g. "C". unit_type (str): A type of unit. e.g. "charge" """ if unit_type is not None and str(unit) not in ALL_UNITS[unit_type]: raise UnitError(f"{unit} is not a supported unit for {unit_type}") - self._unit = Unit(unit) + + self._unit = unit if isinstance(unit, Unit) else Unit(unit) self._unit_type = unit_type - def __new__(cls, val, unit, unit_type=None) -> Self: - """Overrides __new__ since we are subclassing a Python primitive.""" + def __new__( + cls, + val, + unit: str | Unit, + unit_type: str | None = None, + ) -> Self: + """Override __new__.""" new = float.__new__(cls, val) - new._unit = Unit(unit) + new._unit = unit if isinstance(unit, Unit) else Unit(unit) new._unit_type = unit_type return new @@ -323,7 +343,11 @@ def __add__(self, other): val = other if other.unit != self._unit: val = other.to(self._unit) - return FloatWithUnit(float(self) + val, unit_type=self._unit_type, unit=self._unit) + return type(self)( + float(self) + val, + unit_type=self._unit_type, + unit=self._unit, + ) def __sub__(self, other): if not hasattr(other, "unit_type"): @@ -333,29 +357,54 @@ def __sub__(self, other): val = other if other.unit != self._unit: val = other.to(self._unit) - return FloatWithUnit(float(self) - val, unit_type=self._unit_type, unit=self._unit) + return type(self)( + float(self) - val, + unit_type=self._unit_type, + unit=self._unit, + ) def __mul__(self, other): - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit) + cls = type(self) + if not isinstance(other, cls): + return cls( + float(self) * other, + unit_type=self._unit_type, + unit=self._unit, + ) + return cls( + float(self) * other, + unit_type=None, + unit=self._unit * other._unit, + ) def __rmul__(self, other): - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(float(self) * other, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(float(self) * other, unit_type=None, unit=self._unit * other._unit) + if not isinstance(other, type(self)): + return type(self)( + float(self) * other, + unit_type=self._unit_type, + unit=self._unit, + ) + return type(self)( + float(self) * other, + unit_type=None, + unit=self._unit * other._unit, + ) def __pow__(self, i): - return FloatWithUnit(float(self) ** i, unit_type=None, unit=self._unit**i) + return type(self)(float(self) ** i, unit_type=None, unit=self._unit**i) def __truediv__(self, other): val = super().__truediv__(other) - if not isinstance(other, FloatWithUnit): - return FloatWithUnit(val, unit_type=self._unit_type, unit=self._unit) - return FloatWithUnit(val, unit_type=None, unit=self._unit / other._unit) + if not isinstance(other, type(self)): + return type(self)(val, unit_type=self._unit_type, unit=self._unit) + return type(self)(val, unit_type=None, unit=self._unit / other._unit) def __neg__(self): - return FloatWithUnit(super().__neg__(), unit_type=self._unit_type, unit=self._unit) + return type(self)( + super().__neg__(), + unit_type=self._unit_type, + unit=self._unit, + ) def __getnewargs__(self): """Used by pickle to recreate object.""" @@ -390,8 +439,10 @@ def unit(self) -> Unit: @classmethod def from_str(cls, string: str) -> Self: - """Parse string to FloatWithUnit. - Example: Memory.from_str("1. Mb"). + """Convert string to FloatWithUnit. + + Example usage: + Memory.from_str("1. Mb"). """ # Extract num and unit string. string = string.strip() @@ -408,90 +459,85 @@ def from_str(cls, string: str) -> Self: return cls(num, unit, unit_type=unit_type) return cls(num, unit, unit_type=None) - def to(self, new_unit): - """Conversion to a new_unit. Right now, only supports 1 to 1 mapping of - units of each type. + def to(self, new_unit: str | Unit) -> Self: + """Convert to a new unit. Right now, only support + 1 to 1 mapping of units of each type. Args: - new_unit: New unit type. + new_unit (str | Unit): New unit type. Returns: - A FloatWithUnit object in the new units. + FloatWithUnit in the new unit. Example usage: - >>> e = Energy(1.1, "eV") - >>> e = Energy(1.1, "Ha") - >>> e.to("eV") + >>> energy = Energy(1.1, "eV") + >>> energy = Energy(1.1, "Ha") + >>> energy.to("eV") 29.932522246 eV """ - return FloatWithUnit( - self * self.unit.get_conversion_factor(new_unit), - unit_type=self._unit_type, - unit=new_unit, - ) + new_value = self * self.unit.get_conversion_factor(new_unit) + return type(self)(new_value, unit_type=self._unit_type, unit=new_unit) @property def as_base_units(self): """This FloatWithUnit in base SI units, including derived units. Returns: - A FloatWithUnit object in base SI units + FloatWithUnit in base SI units """ return self.to(self.unit.as_base_units[0]) @property - def supported_units(self): + def supported_units(self) -> tuple: """Supported units for specific unit type.""" - return tuple(ALL_UNITS[self._unit_type]) + if self.unit_type is None: + raise RuntimeError("Cannot get supported unit for None.") + + return tuple(ALL_UNITS[self.unit_type]) class ArrayWithUnit(np.ndarray): """Subclasses numpy.ndarray to attach a unit type. Typically, you should use the pre-defined unit type subclasses such as EnergyArray, - LengthArray, etc. instead of using ArrayWithFloatWithUnit directly. + LengthArray, etc. instead of using ArrayWithUnit directly. - Supports conversion, addition and subtraction of the same unit type. e.g. + Support conversion, addition and subtraction of the same unit type. e.g. 1 m + 20 cm will be automatically converted to 1.2 m (units follow the leftmost quantity). - >>> a = EnergyArray([1, 2], "Ha") - >>> b = EnergyArray([1, 2], "eV") - >>> c = a + b - >>> print(c) + >>> energy_arr_a = EnergyArray([1, 2], "Ha") + >>> energy_arr_b = EnergyArray([1, 2], "eV") + >>> energy_arr_c = energy_arr_a + energy_arr_b + >>> print(energy_arr_c) [ 1.03674933 2.07349865] Ha - >>> c.to("eV") + >>> energy_arr_c.to("eV") array([ 28.21138386, 56.42276772]) eV """ - def __new__(cls, input_array, unit, unit_type=None) -> Self: + def __new__( + cls, + input_array: NDArray, + unit: str | Unit, + unit_type: str | None = None, + ) -> Self: """Override __new__.""" # Input array is an already formed ndarray instance # We first cast to be our class type obj = np.asarray(input_array).view(cls) - # add the new attributes to the created instance - obj._unit = Unit(unit) + # Add the new attributes to the created instance + obj._unit = unit if isinstance(unit, Unit) else Unit(unit) obj._unit_type = unit_type return obj - def __array_finalize__(self, obj): - """See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html for - comments. + def __array_finalize__(self, obj) -> None: + """See http://docs.scipy.org/doc/numpy/user/basics.subclassing.html + for comments. """ if obj is None: return self._unit = getattr(obj, "_unit", None) self._unit_type = getattr(obj, "_unit_type", None) - @property - def unit_type(self) -> str: - """The type of unit. Energy, Charge, etc.""" - return self._unit_type - - @property - def unit(self) -> str: - """The unit, e.g. "eV".""" - return self._unit - def __reduce__(self): reduce = list(super().__reduce__()) reduce[2] = {"np_state": reduce[2], "_unit": self._unit} @@ -525,14 +571,18 @@ def __sub__(self, other): if other.unit != self.unit: other = other.to(self.unit) - return type(self)(np.array(self) - np.array(other), unit_type=self.unit_type, unit=self.unit) + return type(self)( + np.array(self) - np.array(other), + unit_type=self.unit_type, + unit=self.unit, + ) def __mul__(self, other): - # TODO Here we have the most important difference between FloatWithUnit and - # ArrayWithFloatWithUnit: - # If other does not have units, I return an object with the same units + # TODO Here we have the most important difference between + # FloatWithUnit and ArrayWithUnit: + # If other does not have units, return an object with the same units # as self. - # if other *has* units, I return an object *without* units since + # If other *has* units, return an object *without* units since # taking into account all the possible derived quantities would be # too difficult. # Moreover Energy(1.0) * Time(1.0, "s") returns 1.0 Ha that is a @@ -546,7 +596,10 @@ def __mul__(self, other): ) # Cannot use super since it returns an instance of self.__class__ # while here we want a bare numpy array. - return type(self)(np.array(self).__mul__(np.array(other)), unit=self.unit * other.unit) + return type(self)( + np.array(self).__mul__(np.array(other)), + unit=self.unit * other.unit, + ) def __rmul__(self, other): if not hasattr(other, "unit_type"): @@ -555,29 +608,48 @@ def __rmul__(self, other): unit_type=self._unit_type, unit=self._unit, ) - return type(self)(np.array(self) * np.array(other), unit=self.unit * other.unit) + return type(self)( + np.array(self) * np.array(other), + unit=self.unit * other.unit, + ) def __truediv__(self, other): if not hasattr(other, "unit_type"): - return type(self)(np.array(self) / np.array(other), unit_type=self._unit_type, unit=self._unit) - return type(self)(np.array(self) / np.array(other), unit=self.unit / other.unit) + return type(self)( + np.array(self) / np.array(other), + unit_type=self._unit_type, + unit=self._unit, + ) + return type(self)( + np.array(self) / np.array(other), + unit=self.unit / other.unit, + ) def __neg__(self): return type(self)(-np.array(self), unit_type=self.unit_type, unit=self.unit) - def to(self, new_unit): - """Conversion to a new_unit. + @property + def unit_type(self) -> str | None: + """The type of unit. Energy, Charge, etc.""" + return self._unit_type + + @property + def unit(self) -> Unit: + """The unit, e.g. "eV".""" + return cast(Unit, self._unit) + + def to(self, new_unit: str | Unit) -> Self: + """Convert to a new unit. Args: - new_unit: - New unit type. + new_unit (str | Unit): New unit type. Returns: - A ArrayWithFloatWithUnit object in the new units. + ArrayWithUnit in the new unit. Example usage: - >>> e = EnergyArray([1, 1.1], "Ha") - >>> e.to("eV") + >>> energy = EnergyArray([1, 1.1], "Ha") + >>> energy.to("eV") array([ 27.21138386, 29.93252225]) eV """ return type(self)( @@ -591,18 +663,21 @@ def as_base_units(self): """This ArrayWithUnit in base SI units, including derived units. Returns: - An ArrayWithUnit object in base SI units + ArrayWithUnit in base SI units """ return self.to(self.unit.as_base_units[0]) - # TODO abstract base class property? @property - def supported_units(self): + def supported_units(self) -> dict: + # TODO abstract base class property? """Supported units for specific unit type.""" + if self.unit_type is None: + raise RuntimeError("Cannot get supported unit for None.") + return ALL_UNITS[self.unit_type] - # TODO abstract base class method? - def conversions(self): + def conversions(self) -> str: + # TODO abstract base class method? """Get a string showing the available conversions. Useful tool in interactive mode. """ @@ -695,10 +770,13 @@ def _my_partial(func, *args, **kwargs): """ -def obj_with_unit(obj: Any, unit: str) -> FloatWithUnit | ArrayWithUnit | dict[str, FloatWithUnit | ArrayWithUnit]: +def obj_with_unit( + obj: Any, + unit: str, +) -> FloatWithUnit | ArrayWithUnit | dict[str, FloatWithUnit | ArrayWithUnit]: """Get a FloatWithUnit instance if obj is scalar, a dictionary of objects with units if obj is a dict, else an instance of - ArrayWithFloatWithUnit. + ArrayWithUnit. Args: obj (Any): Object to be given a unit. @@ -716,7 +794,7 @@ def obj_with_unit(obj: Any, unit: str) -> FloatWithUnit | ArrayWithUnit | dict[s def unitized(unit): - """Useful decorator to assign units to the output of a function. You can also + """Decorator to assign units to the output of a function. You can also use it to standardize the output units of a function that already returns a FloatWithUnit or ArrayWithUnit. For sequences, all values in the sequences are assigned the same unit. It works with Python sequences only. The creation diff --git a/tests/core/test_trajectory.py b/tests/core/test_trajectory.py index 03fe63181fe..8c2684ca208 100644 --- a/tests/core/test_trajectory.py +++ b/tests/core/test_trajectory.py @@ -1,8 +1,10 @@ from __future__ import annotations import copy +import re import numpy as np +import pytest from numpy.testing import assert_allclose from pymatgen.core.lattice import Lattice @@ -480,3 +482,11 @@ def test_from_file(self): # Check composition of the first frame of the trajectory assert traj[0].formula == "Li2 Mn2 O4" + + def test_index_error(self): + with pytest.raises(IndexError, match="index=100 out of range, trajectory only has 100 frames"): + self.traj[100] + with pytest.raises( + TypeError, match=re.escape("bad index='test', expected one of [int, slice, list[int], numpy.ndarray]") + ): + self.traj["test"] diff --git a/tests/core/test_units.py b/tests/core/test_units.py index 00cebd8272c..ae785b20f7a 100644 --- a/tests/core/test_units.py +++ b/tests/core/test_units.py @@ -169,9 +169,9 @@ def test_neg(self): assert FloatWithUnit(-5, "MPa") == -x -class TestArrayWithFloatWithUnit(PymatgenTest): +class TestArrayWithUnit(PymatgenTest): def test_energy(self): - """Similar to FloatWithUnitTest.test_energy. + """Similar to TestFloatWithUnit.test_energy. Check whether EnergyArray and FloatWithUnit have same behavior. # TODO