Skip to content

Commit

Permalink
mv figshare.py to remote/figshare.py and extract download_file from data.py to new remote/fetch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Feb 8, 2025
1 parent c832147 commit 2f7ea53
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 169 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
rev: v0.9.5
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -57,7 +57,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.19.0
rev: v9.20.0
hooks:
- id: eslint
types: [file]
Expand Down
44 changes: 0 additions & 44 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
~/.cache/matbench-discovery.
"""

import builtins
import io
import os
import sys
import traceback
import zipfile
from collections import defaultdict
from collections.abc import Callable, Sequence
Expand All @@ -25,7 +23,6 @@

import ase.io
import pandas as pd
import requests
import yaml
from ase import Atoms
from pymatviz.enums import Key
Expand Down Expand Up @@ -197,47 +194,6 @@ def ase_atoms_to_zip(
zip_file.writestr(f"{mat_id}.extxyz", buffer.getvalue())


def download_file(file_path: str, url: str) -> None:
    """Download the file from the given URL to the given file path.

    Prints rather than raises if the file cannot be downloaded.

    Args:
        file_path (str): Destination path. Missing parent directories are created.
            A bare file name (no directory component) is written relative to the
            current working directory.
        url (str): URL fetched with a GET request (5 second timeout).
    """
    file_dir = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create the parent directory
    # when file_path actually has a directory component
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    try:
        response = requests.get(url, timeout=5)

        # turn 4xx/5xx responses into an HTTPError so nothing is written to disk
        response.raise_for_status()

        with open(file_path, "wb") as file:
            file.write(response.content)
    except requests.exceptions.RequestException:
        print(f"Error downloading {url=}\nto {file_path=}.\n{traceback.format_exc()}")


def maybe_auto_download_file(url: str, abs_path: str, label: str | None = None) -> None:
"""Download file if not exist and user confirms or auto-download is enabled."""
if os.path.isfile(abs_path):
return

# whether to auto-download model prediction files without prompting
auto_download_files = os.getenv("MBD_AUTO_DOWNLOAD_FILES", "true").lower() == "true"

is_ipython = hasattr(builtins, "__IPYTHON__")
# default to 'y' if auto-download enabled or not in interactive session (TTY
# or iPython)
answer = (
"y"
if auto_download_files or not (is_ipython or sys.stdin.isatty())
else input(
f"{abs_path!r} associated with {label=} does not exist. Download it "
"now? This will cache the file for future use. [y/n] "
)
)
if answer.lower().strip() == "y":
print(f"Downloading {label!r} from {url!r} to {abs_path!r}")
download_file(abs_path, url)


df_wbm = pd.read_csv(DataFiles.wbm_summary.path)
# str() around Key.mat_id added for https://github.com/janosh/matbench-discovery/issues/81
df_wbm.index = df_wbm[str(Key.mat_id)]
Expand Down
9 changes: 1 addition & 8 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import yaml

from matbench_discovery import DEFAULT_CACHE_DIR, PKG_DIR, ROOT
from matbench_discovery.remote.fetch import download_file, maybe_auto_download_file

eV_per_atom = pmv.enums.eV_per_atom # noqa: N816
T = TypeVar("T", bound="Files")
Expand Down Expand Up @@ -358,8 +359,6 @@ def yaml_path(self) -> str:
@property
def discovery_path(self) -> str:
"""Prediction file path associated with the model."""
from matbench_discovery.data import maybe_auto_download_file

rel_path = self.metrics.get("discovery", {}).get("pred_file")
file_url = self.metrics.get("discovery", {}).get("pred_file_url")
if not rel_path:
Expand All @@ -375,8 +374,6 @@ def geo_opt_path(self) -> str | None:
"""File path associated with the file URL if it exists, otherwise
download the file first, then return the path.
"""
from matbench_discovery.data import maybe_auto_download_file

geo_opt_metrics = self.metrics.get("geo_opt", {})
if geo_opt_metrics in ("not available", "not applicable"):
return None
Expand All @@ -395,8 +392,6 @@ def kappa_103_path(self) -> str | None:
"""File path associated with the file URL if it exists, otherwise
download the file first, then return the path.
"""
from matbench_discovery.data import maybe_auto_download_file

phonons_metrics = self.metrics.get("phonons", {})
if phonons_metrics in ("not available", "not applicable"):
return None
Expand Down Expand Up @@ -495,8 +490,6 @@ def path(self) -> str:
"""File path associated with the file URL if it exists, otherwise
download the file first, then return the path.
"""
from matbench_discovery.data import download_file

key, rel_path = self.name, self.rel_path

if rel_path not in self.yaml[key]["path"]:
Expand Down
49 changes: 49 additions & 0 deletions matbench_discovery/remote/fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Files download functions."""

import builtins
import os
import sys
import traceback

import requests


def download_file(file_path: str, url: str) -> None:
    """Download the file from the given URL to the given file path.

    Prints rather than raises if the file cannot be downloaded.

    Args:
        file_path (str): Destination path. Missing parent directories are created.
            A bare file name (no directory component) is written relative to the
            current working directory.
        url (str): URL fetched with a GET request (5 second timeout).
    """
    file_dir = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create the parent directory
    # when file_path actually has a directory component
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    try:
        response = requests.get(url, timeout=5)

        # turn 4xx/5xx responses into an HTTPError so nothing is written to disk
        response.raise_for_status()

        with open(file_path, "wb") as file:
            file.write(response.content)
    except requests.exceptions.RequestException:
        print(f"Error downloading {url=}\nto {file_path=}.\n{traceback.format_exc()}")


def maybe_auto_download_file(url: str, abs_path: str, label: str | None = None) -> None:
"""Download file if not exist and user confirms or auto-download is enabled."""
if os.path.isfile(abs_path):
return

# whether to auto-download model prediction files without prompting
auto_download_files = os.getenv("MBD_AUTO_DOWNLOAD_FILES", "true").lower() == "true"

is_ipython = hasattr(builtins, "__IPYTHON__")
# default to 'y' if auto-download enabled or not in interactive session (TTY
# or iPython)
answer = (
"y"
if auto_download_files or not (is_ipython or sys.stdin.isatty())
else input(
f"{abs_path!r} associated with {label=} does not exist. Download it "
"now? This will cache the file for future use. [y/n] "
)
)
if answer.lower().strip() == "y":
print(f"Downloading {label!r} from {url!r} to {abs_path!r}")
download_file(abs_path, url)
File renamed without changes.
2 changes: 1 addition & 1 deletion scripts/upload_data_files_to_figshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tqdm import tqdm

import matbench_discovery.figshare as figshare
import matbench_discovery.remote.figshare as figshare
from matbench_discovery import DATA_DIR, PKG_DIR, ROOT
from matbench_discovery.data import round_trip_yaml
from matbench_discovery.enums import DataFiles
Expand Down
2 changes: 1 addition & 1 deletion scripts/upload_model_preds_to_figshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import yaml
from tqdm import tqdm

import matbench_discovery.figshare as figshare
import matbench_discovery.remote.figshare as figshare
from matbench_discovery import PKG_DIR, ROOT
from matbench_discovery.data import round_trip_yaml
from matbench_discovery.enums import Model
Expand Down
112 changes: 112 additions & 0 deletions tests/remote/test_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
from pathlib import Path
from unittest.mock import patch

import pytest
import requests

from matbench_discovery.remote.fetch import maybe_auto_download_file


def test_download_file(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
    """Test download_file function.

    Covers a successful download (response body written to disk) and a failed one
    (HTTP 404: error printed to stdout, nothing raised, existing file untouched).
    """

    from matbench_discovery.remote.fetch import download_file

    url = "https://example.com/test.txt"
    test_content = b"test content"
    dest_path = tmp_path / "test.txt"

    # Mock successful request
    mock_response = requests.Response()
    mock_response.status_code = 200
    mock_response._content = test_content  # noqa: SLF001

    with patch("requests.get", return_value=mock_response):
        download_file(str(dest_path), url)
    assert dest_path.read_bytes() == test_content

    # Mock failed request
    mock_response = requests.Response()
    mock_response.status_code = 404
    mock_response._content = b"Not found"  # noqa: SLF001

    with patch("requests.get", return_value=mock_response):
        download_file(str(dest_path), url)  # Should print error but not raise

    stdout, stderr = capsys.readouterr()
    assert f"Error downloading {url=}" in stdout
    assert stderr == ""
    # the error body of the failed response must not overwrite the earlier download
    assert dest_path.read_bytes() == test_content


def test_maybe_auto_download_file(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture
) -> None:
    """Test auto-download behavior of maybe_auto_download_file function.

    The scenarios run in sequence against the same abs_path, so each one relies on
    the file state (present/removed) left behind by the previous one.
    """
    url = "https://example.com/file.txt"
    abs_path = f"{tmp_path}/test/file.txt"
    os.makedirs(os.path.dirname(abs_path), exist_ok=True)

    # Mock successful request
    mock_response = requests.Response()
    mock_response.status_code = 200
    mock_response._content = b"test content"  # noqa: SLF001

    # Test 1: Auto-download enabled (default)
    monkeypatch.setenv("MBD_AUTO_DOWNLOAD_FILES", "true")
    with patch("requests.get", return_value=mock_response):
        maybe_auto_download_file(url, abs_path, label="test")
    stdout, _ = capsys.readouterr()
    assert f"Downloading 'test' from {url!r}" in stdout
    assert os.path.isfile(abs_path)

    # Test 2: Auto-download disabled
    os.remove(abs_path)
    monkeypatch.setenv("MBD_AUTO_DOWNLOAD_FILES", "false")
    assert not os.path.isfile(abs_path)

    # Mock user input 'n' to skip download
    with (
        patch("requests.get", return_value=mock_response),
        patch("builtins.input", return_value="n"),
        patch("sys.stdin.isatty", return_value=True),  # force interactive mode
    ):
        maybe_auto_download_file(url, abs_path, label="test")
    assert not os.path.isfile(abs_path)

    # Test 3: Auto-download disabled but user confirms
    with (
        patch("requests.get", return_value=mock_response),
        patch("builtins.input", return_value="y"),
        patch("sys.stdin.isatty", return_value=True),  # force interactive mode
    ):
        maybe_auto_download_file(url, abs_path, label="test")
    stdout, _ = capsys.readouterr()
    assert f"Downloading 'test' from {url!r}" in stdout
    assert os.path.isfile(abs_path)

    # Test 4: File already exists (no download attempt)
    with patch("requests.get") as mock_get:
        maybe_auto_download_file(url, abs_path, label="test")
    mock_get.assert_not_called()

    # Test 5: Non-interactive session (auto-download)
    os.remove(abs_path)
    with (
        patch("requests.get", return_value=mock_response),
        patch("sys.stdin.isatty", return_value=False),
    ):
        maybe_auto_download_file(url, abs_path, label="test")
    stdout, _ = capsys.readouterr()
    assert f"Downloading 'test' from {url!r}" in stdout
    assert os.path.isfile(abs_path)

    # Test 6: IPython session with auto-download disabled
    os.remove(abs_path)
    with (
        patch("requests.get", return_value=mock_response),
        patch("builtins.input", return_value="n"),
        # stdin is NOT a TTY here, so the prompt can only be triggered by the
        # IPython marker; without it this scenario would just repeat Test 2
        patch("sys.stdin.isatty", return_value=False),
        patch("builtins.__IPYTHON__", new=True, create=True),
    ):
        maybe_auto_download_file(url, abs_path, label="test")
    assert not os.path.isfile(abs_path)
13 changes: 7 additions & 6 deletions tests/test_figshare.py → tests/remote/test_figshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import requests

import matbench_discovery.figshare as figshare
import matbench_discovery.remote.figshare as figshare


@pytest.mark.parametrize(
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_create_article_variants(
) -> None:
"""Test article creation with different metadata combinations."""
with patch(
"matbench_discovery.figshare.make_request",
"matbench_discovery.remote.figshare.make_request",
side_effect=[{"location": "loc"}, {"id": article_id}],
):
assert figshare.create_article(metadata, verbose=True) == article_id
Expand Down Expand Up @@ -194,9 +194,10 @@ def mock_make_request(method: str, url: str, **kwargs: Any) -> Any:
return mock_responses[method]

with (
patch("matbench_discovery.figshare.ROOT", str(tmp_path)),
patch("matbench_discovery.remote.figshare.ROOT", str(tmp_path)),
patch(
"matbench_discovery.figshare.make_request", side_effect=mock_make_request
"matbench_discovery.remote.figshare.make_request",
side_effect=mock_make_request,
),
):
assert figshare.upload_file(12345, str(test_file), file_name=file_name) == 67890
Expand All @@ -211,7 +212,7 @@ def mock_make_request(method: str, url: str, **kwargs: Any) -> Any:
@pytest.mark.parametrize("files", [[], DUMMY_FILES]) # Empty and non-empty
def test_list_article_files(files: list[dict[str, Any]]) -> None:
"""Test list_article_files with various file configurations."""
with patch("matbench_discovery.figshare.make_request", return_value=files):
with patch("matbench_discovery.remote.figshare.make_request", return_value=files):
assert figshare.list_article_files(12345) == files


Expand Down Expand Up @@ -273,7 +274,7 @@ def test_get_existing_files(
files: list[dict[str, Any]], expected: dict[str, dict[str, Any]]
) -> None:
"""Test get_existing_files with various file configurations."""
with patch("matbench_discovery.figshare.make_request", return_value=files):
with patch("matbench_discovery.remote.figshare.make_request", return_value=files):
assert figshare.get_existing_files(12345) == expected


Expand Down
Loading

0 comments on commit 2f7ea53

Please sign in to comment.