Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,5 @@ ENV/

# Doc folder
_autosummary

uv.lock
73 changes: 43 additions & 30 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TYPE_CHECKING,
ForwardRef,
Generic,
Optional,
TypeVar,
get_args,
)
Expand All @@ -37,7 +38,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import api_sanitize, load_json, validate_ids
from mp_api.client.core.utils import load_json, validate_ids

try:
import boto3
Expand All @@ -54,6 +55,8 @@
if TYPE_CHECKING:
from typing import Any, Callable

from pydantic.fields import FieldInfo

try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
Expand Down Expand Up @@ -147,12 +150,6 @@ def __init__(
else:
self._s3_client = None

self.document_model = (
api_sanitize(self.document_model) # type: ignore
if self.document_model is not None
else None # type: ignore
)

@property
def session(self) -> requests.Session:
if not self._session:
Expand Down Expand Up @@ -1043,10 +1040,8 @@ def _convert_to_model(self, data: list[dict]):
(list[MPDataDoc]): List of MPDataDoc objects

"""
raw_doc_list = [self.document_model.model_validate(d) for d in data] # type: ignore

if len(raw_doc_list) > 0:
data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0])
if len(data) > 0:
data_model, set_fields, _ = self._generate_returned_model(data[0])

data = [
data_model(
Expand All @@ -1056,44 +1051,56 @@ def _convert_to_model(self, data: list[dict]):
if field in set_fields
}
)
for raw_doc in raw_doc_list
for raw_doc in data
]

return data

def _generate_returned_model(self, doc):
def _generate_returned_model(
self, doc: dict[str, Any]
) -> tuple[BaseModel, list[str], list[str]]:
model_fields = self.document_model.model_fields

set_fields = doc.model_fields_set
set_fields = [k for k in doc if k in model_fields]
unset_fields = [field for field in model_fields if field not in set_fields]

# Update with locals() from external module if needed
other_vars = {}
if any(
isinstance(field_meta.annotation, ForwardRef)
for field_meta in model_fields.values()
) or any(
isinstance(typ, ForwardRef)
for field_meta in model_fields.values()
for typ in get_args(field_meta.annotation)
):
other_vars = vars(import_module(self.document_model.__module__))

include_fields = {
name: (
model_fields[name].annotation,
model_fields[name],
vars(import_module(self.document_model.__module__))

include_fields: dict[str, tuple[type, FieldInfo]] = {}
for name in set_fields:
field_copy = model_fields[name]._copy()
field_copy.default = None
include_fields[name] = (
Optional[model_fields[name].annotation],
field_copy,
)
for name in set_fields
}

data_model = create_model( # type: ignore
"MPDataDoc",
**include_fields,
# TODO fields_not_requested is not the same as unset_fields
# i.e. field could be requested but not available in the raw doc
fields_not_requested=(list[str], unset_fields),
__base__=self.document_model,
__doc__=".".join(
[
getattr(self.document_model, k, "")
for k in ("__module__", "__name__")
]
),
__module__=self.document_model.__module__,
)
if other_vars:
data_model.model_rebuild(_types_namespace=other_vars)
# if other_vars:
# data_model.model_rebuild(_types_namespace=other_vars)

orig_rester_name = self.document_model.__name__

def new_repr(self) -> str:
extra = ",\n".join(
Expand All @@ -1102,7 +1109,7 @@ def new_repr(self) -> str:
if n == "fields_not_requested" or n in set_fields
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
s = f"\033[4m\033[1m{self.__class__.__name__}<{orig_rester_name}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
return s

def new_str(self) -> str:
Expand Down Expand Up @@ -1216,8 +1223,14 @@ def get_data_by_id(
stacklevel=2,
)

if self.primary_key in ["material_id", "task_id"]:
validate_ids([document_id])
if self.primary_key in [
"material_id",
"task_id",
"battery_id",
"spectrum_id",
"thermo_id",
]:
document_id = validate_ids([document_id])[0]

if isinstance(fields, str): # pragma: no cover
fields = (fields,) # type: ignore
Expand Down
136 changes: 61 additions & 75 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,46 @@
from __future__ import annotations

import re
from functools import cache
from typing import Optional, get_args
from typing import TYPE_CHECKING, Literal

import orjson
from emmet.core.utils import get_flat_models_from_model
from monty.json import MontyDecoder, MSONable
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo
from emmet.core import __version__ as _EMMET_CORE_VER
from monty.json import MontyDecoder
from packaging.version import parse as parse_version

from mp_api.client.core.settings import MAPIClientSettings

if TYPE_CHECKING:
from monty.json import MSONable


def _compare_emmet_ver(
    ref_version: str, op: Literal["==", ">", ">=", "<", "<="]
) -> bool:
    """Check the installed emmet-core version against a reference version.

    Example:
        _compare_emmet_ver("0.84.0rc0", "<") is equivalent to
        emmet.core.__version__ < "0.84.0rc0"

    Parameters
    -----------
    ref_version : str
        A reference version of emmet-core
    op : A mathematical operator
    """
    # Map each comparison operator to the name of the corresponding
    # rich-comparison dunder on packaging's Version objects.
    dunder_names = {"==": "eq", ">": "gt", ">=": "ge", "<": "lt", "<=": "le"}
    installed = parse_version(_EMMET_CORE_VER)
    reference = parse_version(ref_version)
    comparator = getattr(installed, f"__{dunder_names.get(op, op)}__")
    return comparator(reference)


if _compare_emmet_ver("0.85.0", ">="):
from emmet.core.mpid_ext import validate_identifier
else:
validate_identifier = None


def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"):
"""Utility to load json in consistent manner."""
Expand All @@ -22,6 +50,26 @@ def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-
return MontyDecoder().process_decoded(data) if deser else data


def _legacy_id_validation(id_list: list[str]) -> list[str]:
"""Legacy utility to validate IDs, pre-AlphaID transition.

This function is temporarily maintained to allow for
backwards compatibility with older versions of emmet, and will
not be preserved.
"""
pattern = "(mp|mvc|mol|mpcule)-.*"
if malformed_ids := {
entry for entry in id_list if re.match(pattern, entry) is None
}:
raise ValueError(
f"{'Entry' if len(malformed_ids) == 1 else 'Entries'}"
f" {', '.join(malformed_ids)}"
f"{'is' if len(malformed_ids) == 1 else 'are'} not formatted correctly!"
)

return id_list


def validate_ids(id_list: list[str]):
"""Function to validate material and task IDs.

Expand All @@ -40,74 +88,12 @@ def validate_ids(id_list: list[str]):
" data for all IDs and filter locally."
)

pattern = "(mp|mvc|mol|mpcule)-.*"

for entry in id_list:
if re.match(pattern, entry) is None:
raise ValueError(f"{entry} is not formatted correctly!")

return id_list


@cache
def api_sanitize(
    pydantic_model: BaseModel,
    fields_to_leave: list[str] | None = None,
    allow_dict_msonable=False,
):
    """Function to clean up pydantic models for the API by:
    1.) Making fields optional
    2.) Allowing dictionaries in-place of the objects for MSONable quantities.

    WARNING: This works in place, so it mutates the model and all sub-models

    Args:
        pydantic_model (BaseModel): Pydantic model to alter
        fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
            Defaults to None.
        allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
            Defaults to False
    """
    # Collect the target model plus every nested pydantic sub-model so the
    # sanitization below is applied to the whole model tree.
    models = [
        model
        for model in get_flat_models_from_model(pydantic_model)
        if issubclass(model, BaseModel)
    ]  # type: list[BaseModel]

    # Each entry of fields_to_leave must be "ModelName.field_name";
    # the assert enforces exactly one dot per entry.
    fields_to_leave = fields_to_leave or []
    fields_tuples = [f.split(".") for f in fields_to_leave]
    assert all(len(f) == 2 for f in fields_tuples)

    for model in models:
        # Fields explicitly excluded from sanitization for this model.
        model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
        for name, field in model.model_fields.items():
            field_type = field.annotation

            if field_type is not None and allow_dict_msonable:
                # Permit plain dicts wherever an MSONable object is
                # annotated, either directly or inside a generic (e.g.
                # Optional[...] / list[...]) via its type arguments.
                if lenient_issubclass(field_type, MSONable):
                    field_type = allow_msonable_dict(field_type)
                else:
                    for sub_type in get_args(field_type):
                        if lenient_issubclass(sub_type, MSONable):
                            allow_msonable_dict(sub_type)

            if name not in model_fields_to_leave:
                # Rebuild the field as Optional[...] with default None so it
                # becomes optional for API responses.
                new_field = FieldInfo.from_annotated_attribute(
                    Optional[field_type], None
                )

                # Carry over metadata that from_annotated_attribute does not
                # preserve on the freshly constructed FieldInfo.
                for attr in (
                    "json_schema_extra",
                    "exclude",
                ):
                    if (val := getattr(field, attr)) is not None:
                        setattr(new_field, attr, val)

                model.model_fields[name] = new_field

        # Force a rebuild so the mutated field definitions take effect on the
        # model's validators/schema.
        model.model_rebuild(force=True)

    return pydantic_model
# TODO: after the transition to AlphaID in the document models,
# The following line should be changed to
# return [validate_identifier(idx,serialize=True) for idx in id_list]
if validate_identifier:
return [str(validate_identifier(idx)) for idx in id_list]
return _legacy_id_validation(id_list)


def allow_msonable_dict(monty_cls: type[MSONable]):
Expand Down
10 changes: 7 additions & 3 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from emmet.core.mpid import MPID, AlphaID
from emmet.core.settings import EmmetSettings
from emmet.core.tasks import TaskDoc
from emmet.core.thermo import ThermoType
from emmet.core.vasp.calc_types import CalcType
from packaging import version
from pymatgen.analysis.phase_diagram import PhaseDiagram
Expand All @@ -24,7 +23,7 @@

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import _compare_emmet_ver, load_json, validate_ids
from mp_api.client.routes import GeneralStoreRester, MessagesRester, UserSettingsRester
from mp_api.client.routes.materials import (
AbsorptionRester,
Expand Down Expand Up @@ -59,6 +58,11 @@
from mp_api.client.routes.materials.materials import MaterialsRester
from mp_api.client.routes.molecules import MoleculeRester

if _compare_emmet_ver("0.85.0", ">="):
from emmet.core.types.enums import ThermoType
else:
from emmet.core.thermo import ThermoType

if TYPE_CHECKING:
from typing import Any, Literal

Expand Down Expand Up @@ -1339,7 +1343,7 @@ def get_charge_density_from_task_id(
"""
kwargs = dict(
bucket="materialsproject-parsed",
key=f"chgcars/{str(task_id)}.json.gz",
key=f"chgcars/{validate_ids([task_id])[0]}.json.gz",
decoder=lambda x: load_json(x, deser=self.monty_decode),
)
chgcar = self.materials.tasks._query_open_data(**kwargs)[0]
Expand Down
4 changes: 2 additions & 2 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def get_bandstructure_from_task_id(self, task_id: str):
"""
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"bandstructures/{task_id}.json.gz",
key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz",
)[0]

if result:
Expand Down Expand Up @@ -428,7 +428,7 @@ def get_dos_from_task_id(self, task_id: str):
"""
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"dos/{task_id}.json.gz",
key=f"dos/{validate_ids([task_id])[0]}.json.gz",
)[0]

if result:
Expand Down
9 changes: 7 additions & 2 deletions mp_api/client/routes/materials/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
from collections import defaultdict

import numpy as np
from emmet.core.thermo import ThermoDoc, ThermoType
from emmet.core.thermo import ThermoDoc
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.core import Element

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.core.utils import _compare_emmet_ver, load_json, validate_ids

if _compare_emmet_ver("0.85.0", ">="):
from emmet.core.types.enums import ThermoType
else:
from emmet.core.thermo import ThermoType


class ThermoRester(BaseRester[ThermoDoc]):
Expand Down
Loading
Loading