Skip to content

Commit

Permalink
Migrate ext.COD from mysql to REST API (#4117)
Browse files Browse the repository at this point in the history
* tweak comments

* reduce default timeout as 10 minutes is unrealistic

* remove mysql test in test

* finish rewrite

* use tighter timeout in test

* capture timeout errors in ci

* make the timeout skip a wrapper

* use a conditional timeout

* better deprecation handle without breaking

* use a formula with only one match
  • Loading branch information
DanielYang59 authored Oct 21, 2024
1 parent a28e1da commit aa9bdbf
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 74 deletions.
126 changes: 69 additions & 57 deletions src/pymatgen/ext/cod.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,97 +27,109 @@

from __future__ import annotations

import re
import subprocess
import warnings
from shutil import which
from typing import TYPE_CHECKING

import requests
from monty.dev import requires

from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure

if TYPE_CHECKING:
from typing import Literal


class COD:
"""An interface to the Crystallography Open Database."""
"""An interface to the Crystallography Open Database.
url = "www.crystallography.net"
Reference:
https://wiki.crystallography.net/RESTful_API/
"""

def query(self, sql: str) -> str:
"""Perform a query.
def __init__(self, timeout: int = 60):
"""Initialize the COD class.
Args:
sql: SQL string
Returns:
Response from SQL query.
timeout (int): request timeout in seconds.
"""
response = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"])
return response.decode("utf-8")
self.timeout = timeout
self.url = "https://www.crystallography.net"
self.api_url = f"{self.url}/cod/result"

@requires(which("mysql"), "mysql must be installed to use this query.")
def get_cod_ids(self, formula) -> list[int]:
"""Query the COD for all cod ids associated with a formula. Requires
mysql executable to be in the path.
def get_cod_ids(self, formula: str) -> list[int]:
"""Query the COD for all COD IDs associated with a formula.
Args:
formula (str): Formula.
Returns:
List of cod ids.
formula (str): The formula to request
"""
# TODO: Remove dependency on external mysql call. MySQL-python package does not support Py3!

# Standardize formula to the version used by COD
# Use hill_formula format as per COD request
cod_formula = Composition(formula).hill_formula
sql = f'select file from data where formula="- {cod_formula} -"' # noqa: S608
text = self.query(sql).split("\n")
cod_ids = []
for line in text:
if match := re.search(r"(\d+)", line):
cod_ids.append(int(match[1]))
return cod_ids

def get_structure_by_id(self, cod_id: int, timeout: int = 600, **kwargs) -> Structure:
"""Query the COD for a structure by id.
# Set up query parameters
params = {"formula": cod_formula, "format": "json"}

response = requests.get(self.api_url, params=params, timeout=self.timeout)

# Raise an exception if the request fails
response.raise_for_status()

return [int(entry["file"]) for entry in response.json()]

def get_structure_by_id(self, cod_id: int, timeout: int | None = None, **kwargs) -> Structure:
"""Query the COD for a structure by ID.
Args:
cod_id (int): COD id.
timeout (int): Timeout for the request in seconds. Default = 600.
kwargs: All kwargs supported by Structure.from_str.
cod_id (int): COD ID.
timeout (int): DEPRECATED. request timeout in seconds.
kwargs: kwargs passed to Structure.from_str.
Returns:
A Structure.
"""
response = requests.get(f"https://{self.url}/cod/{cod_id}.cif", timeout=timeout)
# TODO: remove timeout arg and use class level timeout after 2025-10-17
if timeout is not None:
warnings.warn("separate timeout arg is deprecated, please use class level timeout", DeprecationWarning)
timeout = timeout or self.timeout

response = requests.get(f"{self.url}/cod/{cod_id}.cif", timeout=timeout)
return Structure.from_str(response.text, fmt="cif", **kwargs)

@requires(which("mysql"), "mysql must be installed to use this query.")
def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str | int | Structure]]:
"""Query the COD for structures by formula. Requires mysql executable to
be in the path.
def get_structure_by_formula(
self,
formula: str,
**kwargs,
) -> list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]]:
"""Query the COD for structures by formula.
Args:
formula (str): Chemical formula.
kwargs: All kwargs supported by Structure.from_str.
Returns:
A list of dict of the format [{"structure": Structure, "cod_id": int, "sg": "P n m a"}]
A list of dict of: {"structure": Structure, "cod_id": int, "sg": "P n m a"}
"""
structures: list[dict[str, str | int | Structure]] = []
sql = f'select file, sg from data where formula="- {Composition(formula).hill_formula} -"' # noqa: S608
text = self.query(sql).split("\n")
text.pop(0)
for line in text:
if line.strip():
cod_id, sg = line.split("\t")
response = requests.get(f"https://{self.url}/cod/{cod_id.strip()}.cif", timeout=60)
try:
struct = Structure.from_str(response.text, fmt="cif", **kwargs)
structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg})
except Exception:
warnings.warn(f"\nStructure.from_str failed while parsing CIF file:\n{response.text}")
raise
# Prepare the query parameters
params = {
"formula": Composition(formula).hill_formula,
"format": "json",
}

response = requests.get(self.api_url, params=params, timeout=self.timeout)
response.raise_for_status()

structures: list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]] = []

# Parse the JSON response
for entry in response.json():
cod_id = entry["file"]
sg = entry.get("sg")

try:
struct = self.get_structure_by_id(cod_id, **kwargs)
structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg})

except Exception:
warnings.warn(f"Structure.from_str failed while parsing CIF file for COD ID {cod_id}", stacklevel=2)
raise

return structures
54 changes: 37 additions & 17 deletions tests/ext/test_cod.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,58 @@
from __future__ import annotations

import os
from shutil import which
from unittest import TestCase
from functools import wraps

import pytest
import requests
import urllib3

from pymatgen.ext.cod import COD

if "CI" in os.environ: # test is slow and flaky, skip in CI. see
# https://github.com/materialsproject/pymatgen/pull/3777#issuecomment-2071217785
pytest.skip(allow_module_level=True, reason="Skip COD test in CI")
# Set a tighter timeout in CI
TIMEOUT = 10 if os.getenv("CI") else 60


try:
WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=60).status_code != 200
except (requests.exceptions.ConnectionError, urllib3.exceptions.ConnectTimeoutError):
WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=TIMEOUT).status_code != 200
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout):
WEBSITE_DOWN = True

if WEBSITE_DOWN:
pytest.skip(reason="www.crystallography.net is down", allow_module_level=True)


def skip_on_timeout(func):
"""Skip test in CI when time out."""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout):
if os.getenv("CI"):
pytest.skip("Request timeout in CI environment")
else:
raise

return wrapper


@pytest.mark.skipif(WEBSITE_DOWN, reason="www.crystallography.net is down")
class TestCOD(TestCase):
@pytest.mark.skipif(not which("mysql"), reason="No mysql")
class TestCOD:
@skip_on_timeout
def test_get_cod_ids(self):
ids = COD().get_cod_ids("Li2O")
ids = COD(timeout=TIMEOUT).get_cod_ids("Li2O")
assert len(ids) > 15
assert set(ids).issuperset({1010064, 1011372})

@pytest.mark.skipif(not which("mysql"), reason="No mysql")
@skip_on_timeout
def test_get_structure_by_formula(self):
data = COD().get_structure_by_formula("Li2O")
assert len(data) > 15
assert data[0]["structure"].reduced_formula == "Li2O"
# This formula has only one match (as of 2024-10-17) therefore
# the runtime is shorter (~ 2s for each match)
data = COD(timeout=TIMEOUT).get_structure_by_formula("C3 H18 F6 Fe N9")
assert len(data) >= 1
assert data[0]["structure"].reduced_formula == "FeH18C3(N3F2)3"

@skip_on_timeout
def test_get_structure_by_id(self):
struct = COD().get_structure_by_id(2_002_926)
struct = COD(timeout=TIMEOUT).get_structure_by_id(2_002_926)
assert struct.formula == "Be8 H64 N16 F32"

0 comments on commit aa9bdbf

Please sign in to comment.