diff --git a/.gitignore b/.gitignore index 038a0156..9644749f 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,5 @@ ENV/ # Doc folder _autosummary + +uv.lock diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 34873b08..38d94bf6 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -22,6 +22,7 @@ TYPE_CHECKING, ForwardRef, Generic, + Optional, TypeVar, get_args, ) @@ -37,7 +38,7 @@ from urllib3.util.retry import Retry from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import api_sanitize, load_json, validate_ids +from mp_api.client.core.utils import load_json, validate_ids try: import boto3 @@ -54,6 +55,8 @@ if TYPE_CHECKING: from typing import Any, Callable + from pydantic.fields import FieldInfo + try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover @@ -147,12 +150,6 @@ def __init__( else: self._s3_client = None - self.document_model = ( - api_sanitize(self.document_model) # type: ignore - if self.document_model is not None - else None # type: ignore - ) - @property def session(self) -> requests.Session: if not self._session: @@ -1043,10 +1040,8 @@ def _convert_to_model(self, data: list[dict]): (list[MPDataDoc]): List of MPDataDoc objects """ - raw_doc_list = [self.document_model.model_validate(d) for d in data] # type: ignore - - if len(raw_doc_list) > 0: - data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0]) + if len(data) > 0: + data_model, set_fields, _ = self._generate_returned_model(data[0]) data = [ data_model( @@ -1056,33 +1051,37 @@ def _convert_to_model(self, data: list[dict]): if field in set_fields } ) - for raw_doc in raw_doc_list + for raw_doc in data ] return data - def _generate_returned_model(self, doc): + def _generate_returned_model( + self, doc: dict[str, Any] + ) -> tuple[BaseModel, list[str], list[str]]: model_fields = self.document_model.model_fields - - set_fields = doc.model_fields_set + set_fields = [k for k in doc if k in model_fields] unset_fields = [field for field in model_fields if field not in set_fields] # Update with locals() from external module if needed - other_vars = {} if any( + isinstance(field_meta.annotation, ForwardRef) + for field_meta in model_fields.values() + ) or any( isinstance(typ, ForwardRef) for field_meta in model_fields.values() for typ in get_args(field_meta.annotation) ): - other_vars = vars(import_module(self.document_model.__module__)) - - include_fields = { - name: ( - model_fields[name].annotation, - model_fields[name], + vars(import_module(self.document_model.__module__)) + + include_fields: dict[str, tuple[type, FieldInfo]] = {} + for name in set_fields: + field_copy = model_fields[name]._copy() + field_copy.default = None + include_fields[name] = ( + Optional[model_fields[name].annotation], + field_copy, ) - for name in set_fields - } data_model = create_model( # type: ignore "MPDataDoc", @@ -1090,10 +1089,18 @@ def _generate_returned_model(self, doc): # TODO fields_not_requested is not the same as unset_fields # i.e. field could be requested but not available in the raw doc fields_not_requested=(list[str], unset_fields), - __base__=self.document_model, + __doc__=".".join( + [ + getattr(self.document_model, k, "") + for k in ("__module__", "__name__") + ] + ), + __module__=self.document_model.__module__, ) - if other_vars: - data_model.model_rebuild(_types_namespace=other_vars) + # if other_vars: + # data_model.model_rebuild(_types_namespace=other_vars) + + orig_rester_name = self.document_model.__name__ def new_repr(self) -> str: extra = ",\n".join( @@ -1102,7 +1109,7 @@ def new_repr(self) -> str: if n == "fields_not_requested" or n in set_fields ) - s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 + s = f"\033[4m\033[1m{self.__class__.__name__}<{orig_rester_name}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 return s def new_str(self) -> str: @@ -1216,8 +1223,14 @@ def get_data_by_id( stacklevel=2, ) - if self.primary_key in ["material_id", "task_id"]: - validate_ids([document_id]) + if self.primary_key in [ + "material_id", + "task_id", + "battery_id", + "spectrum_id", + "thermo_id", + ]: + document_id = validate_ids([document_id])[0] if isinstance(fields, str): # pragma: no cover fields = (fields,) # type: ignore diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 67b6f5a8..c2d03fec 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -1,18 +1,46 @@ from __future__ import annotations import re -from functools import cache -from typing import Optional, get_args +from typing import TYPE_CHECKING, Literal import orjson -from emmet.core.utils import get_flat_models_from_model -from monty.json import MontyDecoder, MSONable -from pydantic import BaseModel -from pydantic._internal._utils import lenient_issubclass -from pydantic.fields import FieldInfo +from emmet.core import __version__ as _EMMET_CORE_VER +from monty.json import MontyDecoder +from packaging.version import parse as parse_version from mp_api.client.core.settings import MAPIClientSettings +if TYPE_CHECKING: + from monty.json import MSONable + + +def _compare_emmet_ver( + ref_version: str, op: Literal["==", ">", ">=", "<", "<="] +) -> bool: + """Compare the current emmet-core version to a reference for version guarding. + + Example: + _compare_emmet_ver("0.84.0rc0","<") returns + emmet.core.__version__ < "0.84.0rc0" + + Parameters + ----------- + ref_version : str + A reference version of emmet-core + op : A mathematical operator + """ + op_to_op = {"==": "eq", ">": "gt", ">=": "ge", "<": "lt", "<=": "le"} + return getattr( + parse_version(_EMMET_CORE_VER), + f"__{op_to_op.get(op,op)}__", + )(parse_version(ref_version)) + + +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.mpid_ext import validate_identifier +else: + validate_identifier = None + def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"): """Utility to load json in consistent manner.""" @@ -22,6 +50,26 @@ def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf- return MontyDecoder().process_decoded(data) if deser else data +def _legacy_id_validation(id_list: list[str]) -> list[str]: + """Legacy utility to validate IDs, pre-AlphaID transition. + + This function is temporarily maintained to allow for + backwards compatibility with older versions of emmet, and will + not be preserved. + """ + pattern = "(mp|mvc|mol|mpcule)-.*" + if malformed_ids := { + entry for entry in id_list if re.match(pattern, entry) is None + }: + raise ValueError( + f"{'Entry' if len(malformed_ids) == 1 else 'Entries'}" + f" {', '.join(malformed_ids)}" + f"{'is' if len(malformed_ids) == 1 else 'are'} not formatted correctly!" + ) + + return id_list + + def validate_ids(id_list: list[str]): """Function to validate material and task IDs. @@ -40,74 +88,12 @@ def validate_ids(id_list: list[str]): " data for all IDs and filter locally." ) - pattern = "(mp|mvc|mol|mpcule)-.*" - - for entry in id_list: - if re.match(pattern, entry) is None: - raise ValueError(f"{entry} is not formatted correctly!") - - return id_list - - -@cache -def api_sanitize( - pydantic_model: BaseModel, - fields_to_leave: list[str] | None = None, - allow_dict_msonable=False, -): - """Function to clean up pydantic models for the API by: - 1.) Making fields optional - 2.) Allowing dictionaries in-place of the objects for MSONable quantities. - - WARNING: This works in place, so it mutates the model and all sub-models - - Args: - pydantic_model (BaseModel): Pydantic model to alter - fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field". - Defaults to None. - allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities. - Defaults to False - """ - models = [ - model - for model in get_flat_models_from_model(pydantic_model) - if issubclass(model, BaseModel) - ] # type: list[BaseModel] - - fields_to_leave = fields_to_leave or [] - fields_tuples = [f.split(".") for f in fields_to_leave] - assert all(len(f) == 2 for f in fields_tuples) - - for model in models: - model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]} - for name, field in model.model_fields.items(): - field_type = field.annotation - - if field_type is not None and allow_dict_msonable: - if lenient_issubclass(field_type, MSONable): - field_type = allow_msonable_dict(field_type) - else: - for sub_type in get_args(field_type): - if lenient_issubclass(sub_type, MSONable): - allow_msonable_dict(sub_type) - - if name not in model_fields_to_leave: - new_field = FieldInfo.from_annotated_attribute( - Optional[field_type], None - ) - - for attr in ( - "json_schema_extra", - "exclude", - ): - if (val := getattr(field, attr)) is not None: - setattr(new_field, attr, val) - - model.model_fields[name] = new_field - - model.model_rebuild(force=True) - - return pydantic_model + # TODO: after the transition to AlphaID in the document models, + # The following line should be changed to + # return [validate_identifier(idx,serialize=True) for idx in id_list] + if validate_identifier: + return [str(validate_identifier(idx)) for idx in id_list] + return _legacy_id_validation(id_list) def allow_msonable_dict(monty_cls: type[MSONable]): diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index f14d2d0b..50fd5326 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -10,7 +10,6 @@ from emmet.core.mpid import MPID, AlphaID from emmet.core.settings import EmmetSettings from emmet.core.tasks import TaskDoc -from emmet.core.thermo import ThermoType from emmet.core.vasp.calc_types import CalcType from packaging import version from pymatgen.analysis.phase_diagram import PhaseDiagram @@ -24,7 +23,7 @@ from mp_api.client.core import BaseRester, MPRestError from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import _compare_emmet_ver, load_json, validate_ids from mp_api.client.routes import GeneralStoreRester, MessagesRester, UserSettingsRester from mp_api.client.routes.materials import ( AbsorptionRester, @@ -59,6 +58,11 @@ from mp_api.client.routes.materials.materials import MaterialsRester from mp_api.client.routes.molecules import MoleculeRester +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import ThermoType +else: + from emmet.core.thermo import ThermoType + if TYPE_CHECKING: from typing import Any, Literal @@ -1339,7 +1343,7 @@ def get_charge_density_from_task_id( """ kwargs = dict( bucket="materialsproject-parsed", - key=f"chgcars/{str(task_id)}.json.gz", + key=f"chgcars/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=self.monty_decode), ) chgcar = self.materials.tasks._query_open_data(**kwargs)[0] diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 2fd487d0..c43a88af 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -234,7 +234,7 @@ def get_bandstructure_from_task_id(self, task_id: str): """ result = self._query_open_data( bucket="materialsproject-parsed", - key=f"bandstructures/{task_id}.json.gz", + key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", )[0] if result: @@ -428,7 +428,7 @@ def get_dos_from_task_id(self, task_id: str): """ result = self._query_open_data( bucket="materialsproject-parsed", - key=f"dos/{task_id}.json.gz", + key=f"dos/{validate_ids([task_id])[0]}.json.gz", )[0] if result: diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index ae19a490..509745ef 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -3,12 +3,17 @@ from collections import defaultdict import numpy as np -from emmet.core.thermo import ThermoDoc, ThermoType +from emmet.core.thermo import ThermoDoc from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element from mp_api.client.core import BaseRester -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import _compare_emmet_ver, load_json, validate_ids + +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import ThermoType +else: + from emmet.core.thermo import ThermoType class ThermoRester(BaseRester[ThermoDoc]): diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index b4ee1d82..1b2e28a6 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -1,11 +1,26 @@ from __future__ import annotations -from emmet.core.xas import Edge, Type, XASDoc +from typing import TYPE_CHECKING + +from emmet.core.xas import XASDoc from pymatgen.core.periodic_table import Element from mp_api.client.core import BaseRester from mp_api.client.core.utils import validate_ids +if TYPE_CHECKING: + from mp_api.client.core.utils import _compare_emmet_ver + + if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import XasEdge, XasType + else: + from emmet.core.xas import ( + Edge as XasEdge, + ) + from emmet.core.xas import ( + Type as XasType, + ) + class XASRester(BaseRester[XASDoc]): suffix = "materials/xas" @@ -14,13 +29,13 @@ class XASRester(BaseRester[XASDoc]): def search( self, - edge: Edge | None = None, + edge: XasEdge | None = None, absorbing_element: Element | None = None, formula: str | None = None, chemsys: str | list[str] | None = None, elements: list[str] | None = None, material_ids: list[str] | None = None, - spectrum_type: Type | None = None, + spectrum_type: XasType | None = None, spectrum_ids: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, @@ -30,7 +45,7 @@ def search( """Query core XAS docs using a variety of search criteria. Arguments: - edge (Edge): The absorption edge (e.g. K, L2, L3, L2,3). + edge (XasEdge): The absorption edge (e.g. K, L2, L3, L2,3). absorbing_element (Element): The absorbing element. formula (str): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). @@ -39,7 +54,7 @@ def search( elements (List[str]): A list of elements. material_ids (str, List[str]): A single Material ID string or list of strings (e.g., mp-149, [mp-149, mp-13]). - spectrum_type (Type): Spectrum type (e.g. EXAFS, XAFS, or XANES). + spectrum_type (XasType): Spectrum type (e.g. EXAFS, XAFS, or XANES). spectrum_ids (str, List[str]): A single Spectrum ID string or list of strings (e.g., mp-149-XANES-Li-K, [mp-149-XANES-Li-K, mp-13-XANES-Li-K]). num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. diff --git a/pyproject.toml b/pyproject.toml index 5944f191..d79963cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,14 +25,14 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", "monty>=2024.12.10", - "emmet-core>=0.84.6rc0,<0.85", + "emmet-core>=0.84.6rc0", "smart_open", "boto3", ] dynamic = ["version"] [project.optional-dependencies] -all = ["emmet-core[all]>=0.84.6rc0,<0.85", "custodian", "mpcontribs-client"] +all = ["emmet-core[all]>=0.84.6rc0", "custodian", "mpcontribs-client",] test = [ "pre-commit", "pytest", diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index 40946ae8..04ab5d88 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -141,6 +141,7 @@ tqdm==4.67.1 # via pymatgen typing-extensions==4.15.0 # via + # blake3 # emmet-core # mp-api (pyproject.toml) # pydantic diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 7e348bae..eb59dd2a 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -22,6 +22,8 @@ babel==2.17.0 # via sphinx bibtexparser==1.4.3 # via pymatgen +blake3==1.0.7 + # via emmet-core boltons==25.0.0 # via mpcontribs-client boto3==1.40.43 @@ -60,16 +62,14 @@ dnspython==2.8.0 # pymongo docutils==0.21.2 # via sphinx -emmet-core[all]==0.84.10 +emmet-core[all]==0.85.0 # via mp-api (pyproject.toml) execnet==2.1.1 # via pytest-xdist executing==2.2.1 # via stack-data filelock==3.19.1 - # via - # mdanalysis - # virtualenv + # via virtualenv filetype==1.2.0 # via mpcontribs-client flake8==7.3.0 @@ -84,8 +84,6 @@ fonttools==4.60.1 # via matplotlib fqdn==1.5.1 # via jsonschema -griddataformats==1.0.2 - # via mdanalysis identify==2.6.14 # via pre-commit idna==3.10 @@ -118,7 +116,6 @@ jmespath==1.0.1 # botocore joblib==1.5.2 # via - # mdanalysis # pymatgen # pymatgen-analysis-diffusion # scikit-learn @@ -149,10 +146,8 @@ matminer==0.9.3 matplotlib==3.10.6 # via # ase - # mdanalysis # pymatgen # seaborn - # solvation-analysis matplotlib-inline==0.1.7 # via ipython mccabe==0.7.0 @@ -184,8 +179,6 @@ mpcontribs-client==5.10.4 # via mp-api (pyproject.toml) mpmath==1.3.0 # via sympy -mrcfile==1.5.4 - # via griddataformats msgpack==1.1.1 # via # bravado @@ -211,21 +204,16 @@ numpy==1.26.4 # via # ase # contourpy - # griddataformats # imageio # matminer # matplotlib - # mdanalysis # monty # mpcontribs-client - # mrcfile # pandas - # patsy # pymatgen # pymatgen-analysis-defects # pymatgen-analysis-diffusion # pymatgen-io-validation - # rdkit # robocrys # scikit-image # scikit-learn @@ -233,10 +221,7 @@ numpy==1.26.4 # seaborn # seekpath # shapely - # solvation-analysis # spglib - # statsmodels - # tidynamics # tifffile orjson==3.11.3 # via pymatgen @@ -249,7 +234,6 @@ packaging==25.0 # pytest # scikit-image # sphinx - # statsmodels palettable==3.3.3 # via pymatgen pandas==2.3.3 @@ -264,15 +248,12 @@ parso==0.8.5 # via jedi pathspec==0.12.1 # via mypy -patsy==1.0.1 - # via statsmodels pexpect==4.9.0 # via ipython pillow==11.3.0 # via # imageio # matplotlib - # rdkit # scikit-image pint==0.25 # via mpcontribs-client @@ -284,7 +265,6 @@ plotly==6.3.0 # via # mpcontribs-client # pymatgen - # solvation-analysis pluggy==1.6.0 # via # pytest @@ -367,7 +347,6 @@ pytest==8.4.2 # pytest-cov # pytest-mock # pytest-xdist - # solvation-analysis pytest-asyncio==1.2.0 # via mp-api (pyproject.toml) pytest-cov==7.0.0 @@ -446,14 +425,10 @@ scikit-learn==1.7.2 scipy==1.16.2 # via # ase - # griddataformats - # mdanalysis # pymatgen # robocrys # scikit-image # scikit-learn - # solvation-analysis - # statsmodels seaborn==0.13.2 # via pymatgen-analysis-diffusion seekpath==2.1.0 @@ -477,8 +452,6 @@ smart-open==7.3.1 # via mp-api (pyproject.toml) snowballstemmer==3.0.1 # via sphinx -solvation-analysis==0.4.1 - # via emmet-core spglib==2.6.0 # via # pymatgen @@ -500,8 +473,6 @@ sphinxcontrib-serializinghtml==2.0.0 # via sphinx stack-data==0.6.3 # via ipython -statsmodels==0.14.5 - # via solvation-analysis swagger-spec-validator==3.0.4 # via # bravado-core @@ -523,15 +494,12 @@ tifffile==2025.9.30 tqdm==4.67.1 # via # matminer - # mdanalysis # mpcontribs-client # pymatgen traitlets==5.14.3 # via # ipython # matplotlib-inline -transport-analysis==0.1.2 - # via emmet-core typeguard==4.4.4 # via inflect types-python-dateutil==2.9.0.20250822 @@ -542,6 +510,7 @@ types-setuptools==80.9.0.20250822 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via + # blake3 # bravado # emmet-core # flexcache diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 805c128e..2fefcc2f 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -22,6 +22,8 @@ babel==2.17.0 # via sphinx bibtexparser==1.4.3 # via pymatgen +blake3==1.0.7 + # via emmet-core boltons==25.0.0 # via mpcontribs-client boto3==1.40.43 @@ -60,16 +62,14 @@ dnspython==2.8.0 # pymongo docutils==0.21.2 # via sphinx -emmet-core[all]==0.84.10 +emmet-core[all]==0.85.0 # via mp-api (pyproject.toml) execnet==2.1.1 # via pytest-xdist executing==2.2.1 # via stack-data filelock==3.19.1 - # via - # mdanalysis - # virtualenv + # via virtualenv filetype==1.2.0 # via mpcontribs-client flake8==7.3.0 @@ -84,8 +84,6 @@ fonttools==4.60.1 # via matplotlib fqdn==1.5.1 # via jsonschema -griddataformats==1.0.2 - # via mdanalysis identify==2.6.14 # via pre-commit idna==3.10 @@ -118,7 +116,6 @@ jmespath==1.0.1 # botocore joblib==1.5.2 # via - # mdanalysis # pymatgen # pymatgen-analysis-diffusion # scikit-learn @@ -149,10 +146,8 @@ matminer==0.9.3 matplotlib==3.10.6 # via # ase - # mdanalysis # pymatgen # seaborn - # solvation-analysis matplotlib-inline==0.1.7 # via ipython mccabe==0.7.0 @@ -184,8 +179,6 @@ mpcontribs-client==5.10.4 # via mp-api (pyproject.toml) mpmath==1.3.0 # via sympy -mrcfile==1.5.4 - # via griddataformats msgpack==1.1.1 # via # bravado @@ -211,21 +204,16 @@ numpy==1.26.4 # via # ase # contourpy - # griddataformats # imageio # matminer # matplotlib - # mdanalysis # monty # mpcontribs-client - # mrcfile # pandas - # patsy # pymatgen # pymatgen-analysis-defects # pymatgen-analysis-diffusion # pymatgen-io-validation - # rdkit # robocrys # scikit-image # scikit-learn @@ -233,10 +221,7 @@ numpy==1.26.4 # seaborn # seekpath # shapely - # solvation-analysis # spglib - # statsmodels - # tidynamics # tifffile orjson==3.11.3 # via pymatgen @@ -249,7 +234,6 @@ packaging==25.0 # pytest # scikit-image # sphinx - # statsmodels palettable==3.3.3 # via pymatgen pandas==2.3.3 @@ -264,15 +248,12 @@ parso==0.8.5 # via jedi pathspec==0.12.1 # via mypy -patsy==1.0.1 - # via statsmodels pexpect==4.9.0 # via ipython pillow==11.3.0 # via # imageio # matplotlib - # rdkit # scikit-image pint==0.25 # via mpcontribs-client @@ -284,7 +265,6 @@ plotly==6.3.0 # via # mpcontribs-client # pymatgen - # solvation-analysis pluggy==1.6.0 # via # pytest @@ -367,7 +347,6 @@ pytest==8.4.2 # pytest-cov # pytest-mock # pytest-xdist - # solvation-analysis pytest-asyncio==1.2.0 # via mp-api (pyproject.toml) pytest-cov==7.0.0 @@ -446,14 +425,10 @@ scikit-learn==1.7.2 scipy==1.16.2 # via # ase - # griddataformats - # mdanalysis # pymatgen # robocrys # scikit-image # scikit-learn - # solvation-analysis - # statsmodels seaborn==0.13.2 # via pymatgen-analysis-diffusion seekpath==2.1.0 @@ -477,8 +452,6 @@ smart-open==7.3.1 # via mp-api (pyproject.toml) snowballstemmer==3.0.1 # via sphinx -solvation-analysis==0.4.1 - # via emmet-core spglib==2.6.0 # via # pymatgen @@ -500,8 +473,6 @@ sphinxcontrib-serializinghtml==2.0.0 # via sphinx stack-data==0.6.3 # via ipython -statsmodels==0.14.5 - # via solvation-analysis swagger-spec-validator==3.0.4 # via # bravado-core @@ -523,15 +494,12 @@ tifffile==2025.9.30 tqdm==4.67.1 # via # matminer - # mdanalysis # mpcontribs-client # pymatgen traitlets==5.14.3 # via # ipython # matplotlib-inline -transport-analysis==0.1.2 - # via emmet-core typeguard==4.4.4 # via inflect types-python-dateutil==2.9.0.20250822 diff --git a/tests/materials/test_thermo.py b/tests/materials/test_thermo.py index acb58d8b..1a4a0b38 100644 --- a/tests/materials/test_thermo.py +++ b/tests/materials/test_thermo.py @@ -2,10 +2,15 @@ from core_function import client_search_testing import pytest -from emmet.core.thermo import ThermoType from pymatgen.analysis.phase_diagram import PhaseDiagram from mp_api.client.routes.materials.thermo import ThermoRester +from mp_api.client.core.utils import _compare_emmet_ver + +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import ThermoType +else: + from emmet.core.thermo import ThermoType @pytest.fixture diff --git a/tests/materials/test_xas.py b/tests/materials/test_xas.py index d03af7ab..d461f0a0 100644 --- a/tests/materials/test_xas.py +++ b/tests/materials/test_xas.py @@ -2,10 +2,18 @@ from core_function import client_search_testing import pytest -from emmet.core.xas import Edge, Type from pymatgen.core.periodic_table import Element from mp_api.client.routes.materials.xas import XASRester +from mp_api.client.core.utils import _compare_emmet_ver + +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import XasEdge, XasType +else: + from emmet.core.xas import ( + Type as XasType, + Edge as XasEdge, + ) @pytest.fixture @@ -33,8 +41,8 @@ def rester(): } # type: dict custom_field_tests = { - "edge": Edge.L2_3, - "spectrum_type": Type.EXAFS, + "edge": XasEdge.L2_3, + "spectrum_type": XasType.EXAFS, "absorbing_element": Element("Ce"), "required_elements": [Element("Ce")], "formula": "Ce(WO4)2", diff --git a/tests/test_client.py b/tests/test_client.py index 49a77504..7a5a0ff2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -67,6 +67,10 @@ def test_generic_get_methods(rester): use_document_model=True, ) + docs_check = lambda _docs: all( + rester.document_model.__module__ == _doc.__module__ for _doc in _docs + ) + if name not in ignore_generic: key = rester.primary_key if name not in key_only_resters: @@ -74,12 +78,22 @@ def test_generic_get_methods(rester): key = rester.available_fields[0] doc = rester._query_resource_data({"_limit": 1}, fields=[key])[0] - assert isinstance(doc, rester.document_model) + assert docs_check([doc]) if name not in search_only_resters: - doc = rester.get_data_by_id(doc.model_dump()[key], fields=[key]) - assert isinstance(doc, rester.document_model) + docs = rester.search( + **{key + "s": [doc.model_dump()[key]]}, fields=[key] + ) + assert docs_check(docs) elif name not in special_resters: - doc = rester.get_data_by_id(key_only_resters[name], fields=[key]) - assert isinstance(doc, rester.document_model) + search_method = "search" + if name == "materials_robocrys": + search_method += "_docs" + docs = getattr(rester, search_method)( + **{key + "s": [key_only_resters[name]]}, fields=[key] + ) + with pytest.warns(DeprecationWarning, match="get_data_by_id is deprecated"): + _ = rester.get_data_by_id(key_only_resters[name], fields=[key]) + + assert docs_check(docs) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 8eb55ef0..0cc9d271 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -6,7 +6,6 @@ import numpy as np import pytest from emmet.core.tasks import TaskDoc -from emmet.core.thermo import ThermoType from emmet.core.vasp.calc_types import CalcType from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry, PourbaixDiagram, PourbaixEntry @@ -32,6 +31,12 @@ from mp_api.client import MPRester from mp_api.client.core.client import MPRestError from mp_api.client.core.settings import MAPIClientSettings +from mp_api.client.core.utils import _compare_emmet_ver + +if _compare_emmet_ver("0.85.0", ">="): + from emmet.core.types.enums import ThermoType +else: + from emmet.core.thermo import ThermoType @pytest.fixture() @@ -303,7 +308,7 @@ def test_get_charge_density_from_material_id(self, mpr): "mp-149", inc_task_doc=True ) assert isinstance(chgcar, Chgcar) - assert isinstance(task_doc, TaskDoc) + assert isinstance(TaskDoc.model_validate(task_doc.model_dump()), TaskDoc) def test_get_charge_density_from_task_id(self, mpr): chgcar = mpr.get_charge_density_from_task_id("mp-2246557") @@ -313,7 +318,7 @@ def test_get_charge_density_from_task_id(self, mpr): "mp-2246557", inc_task_doc=True ) assert isinstance(chgcar, Chgcar) - assert isinstance(task_doc, TaskDoc) + assert isinstance(TaskDoc.model_validate(task_doc.model_dump()), TaskDoc) def test_get_wulff_shape(self, mpr): ws = mpr.get_wulff_shape("mp-126")