From aa9bdbf645ab1b96e3c6e56dc190bff0774fa8df Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel) YANG" Date: Tue, 22 Oct 2024 04:54:02 +0800 Subject: [PATCH] Migrate `ext.COD` from `mysql` to REST API (#4117) * tweak comments * reduce default timeout as 10 minutes is unrealistic * remove mysql test in test * finish rewrite * use tighter timeout in test * capture timeout errors in ci * make the timeout skip a wrapper * use a conditional timeout * better deprecation handle without breaking * use a formula with only one match --- src/pymatgen/ext/cod.py | 126 ++++++++++++++++++++++------------------ tests/ext/test_cod.py | 54 +++++++++++------ 2 files changed, 106 insertions(+), 74 deletions(-) diff --git a/src/pymatgen/ext/cod.py b/src/pymatgen/ext/cod.py index dd037f0860d..e23bf1e3028 100644 --- a/src/pymatgen/ext/cod.py +++ b/src/pymatgen/ext/cod.py @@ -27,97 +27,109 @@ from __future__ import annotations -import re -import subprocess import warnings -from shutil import which +from typing import TYPE_CHECKING import requests -from monty.dev import requires from pymatgen.core.composition import Composition from pymatgen.core.structure import Structure +if TYPE_CHECKING: + from typing import Literal + class COD: - """An interface to the Crystallography Open Database.""" + """An interface to the Crystallography Open Database. - url = "www.crystallography.net" + Reference: + https://wiki.crystallography.net/RESTful_API/ + """ - def query(self, sql: str) -> str: - """Perform a query. + def __init__(self, timeout: int = 60): + """Initialize the COD class. Args: - sql: SQL string - - Returns: - Response from SQL query. + timeout (int): request timeout in seconds. """ - response = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"]) - return response.decode("utf-8") + self.timeout = timeout + self.url = "https://www.crystallography.net" + self.api_url = f"{self.url}/cod/result" - @requires(which("mysql"), "mysql must be installed to use this query.") - def get_cod_ids(self, formula) -> list[int]: - """Query the COD for all cod ids associated with a formula. Requires - mysql executable to be in the path. + def get_cod_ids(self, formula: str) -> list[int]: + """Query the COD for all COD IDs associated with a formula. Args: - formula (str): Formula. - - Returns: - List of cod ids. + formula (str): The formula to request """ - # TODO: Remove dependency on external mysql call. MySQL-python package does not support Py3! - - # Standardize formula to the version used by COD + # Use hill_formula format as per COD request cod_formula = Composition(formula).hill_formula - sql = f'select file from data where formula="- {cod_formula} -"' # noqa: S608 - text = self.query(sql).split("\n") - cod_ids = [] - for line in text: - if match := re.search(r"(\d+)", line): - cod_ids.append(int(match[1])) - return cod_ids - def get_structure_by_id(self, cod_id: int, timeout: int = 600, **kwargs) -> Structure: - """Query the COD for a structure by id. + # Set up query parameters + params = {"formula": cod_formula, "format": "json"} + + response = requests.get(self.api_url, params=params, timeout=self.timeout) + + # Raise an exception if the request fails + response.raise_for_status() + + return [int(entry["file"]) for entry in response.json()] + + def get_structure_by_id(self, cod_id: int, timeout: int | None = None, **kwargs) -> Structure: + """Query the COD for a structure by ID. Args: - cod_id (int): COD id. - timeout (int): Timeout for the request in seconds. Default = 600. - kwargs: All kwargs supported by Structure.from_str. + cod_id (int): COD ID. + timeout (int): DEPRECATED. request timeout in seconds. + kwargs: kwargs passed to Structure.from_str. Returns: A Structure. """ - response = requests.get(f"https://{self.url}/cod/{cod_id}.cif", timeout=timeout) + # TODO: remove timeout arg and use class level timeout after 2025-10-17 + if timeout is not None: + warnings.warn("separate timeout arg is deprecated, please use class level timeout", DeprecationWarning) + timeout = timeout or self.timeout + + response = requests.get(f"{self.url}/cod/{cod_id}.cif", timeout=timeout) return Structure.from_str(response.text, fmt="cif", **kwargs) - @requires(which("mysql"), "mysql must be installed to use this query.") - def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str | int | Structure]]: - """Query the COD for structures by formula. Requires mysql executable to - be in the path. + def get_structure_by_formula( + self, + formula: str, + **kwargs, + ) -> list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]]: + """Query the COD for structures by formula. Args: formula (str): Chemical formula. kwargs: All kwargs supported by Structure.from_str. Returns: - A list of dict of the format [{"structure": Structure, "cod_id": int, "sg": "P n m a"}] + A list of dict of: {"structure": Structure, "cod_id": int, "sg": "P n m a"} """ - structures: list[dict[str, str | int | Structure]] = [] - sql = f'select file, sg from data where formula="- {Composition(formula).hill_formula} -"' # noqa: S608 - text = self.query(sql).split("\n") - text.pop(0) - for line in text: - if line.strip(): - cod_id, sg = line.split("\t") - response = requests.get(f"https://{self.url}/cod/{cod_id.strip()}.cif", timeout=60) - try: - struct = Structure.from_str(response.text, fmt="cif", **kwargs) - structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg}) - except Exception: - warnings.warn(f"\nStructure.from_str failed while parsing CIF file:\n{response.text}") - raise + # Prepare the query parameters + params = { + "formula": Composition(formula).hill_formula, + "format": "json", + } + + response = requests.get(self.api_url, params=params, timeout=self.timeout) + response.raise_for_status() + + structures: list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]] = [] + + # Parse the JSON response + for entry in response.json(): + cod_id = entry["file"] + sg = entry.get("sg") + + try: + struct = self.get_structure_by_id(cod_id, **kwargs) + structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg}) + + except Exception: + warnings.warn(f"Structure.from_str failed while parsing CIF file for COD ID {cod_id}", stacklevel=2) + raise return structures diff --git a/tests/ext/test_cod.py b/tests/ext/test_cod.py index 8c9c5d67220..d1eb62fb925 100644 --- a/tests/ext/test_cod.py +++ b/tests/ext/test_cod.py @@ -1,38 +1,58 @@ from __future__ import annotations import os -from shutil import which -from unittest import TestCase +from functools import wraps import pytest import requests -import urllib3 from pymatgen.ext.cod import COD -if "CI" in os.environ: # test is slow and flaky, skip in CI. see - # https://github.com/materialsproject/pymatgen/pull/3777#issuecomment-2071217785 - pytest.skip(allow_module_level=True, reason="Skip COD test in CI") +# Set a tighter timeout in CI +TIMEOUT = 10 if os.getenv("CI") else 60 + try: - WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=60).status_code != 200 -except (requests.exceptions.ConnectionError, urllib3.exceptions.ConnectTimeoutError): + WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=TIMEOUT).status_code != 200 +except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout): WEBSITE_DOWN = True +if WEBSITE_DOWN: + pytest.skip(reason="www.crystallography.net is down", allow_module_level=True) + + +def skip_on_timeout(func): + """Skip test in CI when time out.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout): + if os.getenv("CI"): + pytest.skip("Request timeout in CI environment") + else: + raise + + return wrapper + -@pytest.mark.skipif(WEBSITE_DOWN, reason="www.crystallography.net is down") -class TestCOD(TestCase): - @pytest.mark.skipif(not which("mysql"), reason="No mysql") +class TestCOD: + @skip_on_timeout def test_get_cod_ids(self): - ids = COD().get_cod_ids("Li2O") + ids = COD(timeout=TIMEOUT).get_cod_ids("Li2O") assert len(ids) > 15 + assert set(ids).issuperset({1010064, 1011372}) - @pytest.mark.skipif(not which("mysql"), reason="No mysql") + @skip_on_timeout def test_get_structure_by_formula(self): - data = COD().get_structure_by_formula("Li2O") - assert len(data) > 15 - assert data[0]["structure"].reduced_formula == "Li2O" + # This formula has only one match (as of 2024-10-17) therefore + # the runtime is shorter (~ 2s for each match) + data = COD(timeout=TIMEOUT).get_structure_by_formula("C3 H18 F6 Fe N9") + assert len(data) >= 1 + assert data[0]["structure"].reduced_formula == "FeH18C3(N3F2)3" + @skip_on_timeout def test_get_structure_by_id(self): - struct = COD().get_structure_by_id(2_002_926) + struct = COD(timeout=TIMEOUT).get_structure_by_id(2_002_926) assert struct.formula == "Be8 H64 N16 F32"