-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mv figshare.py to remote/figshare.py and extract download_file from d…
…ata.py to new remote/fetch.py
- Loading branch information
Showing
11 changed files
with
174 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Files download functions.""" | ||
|
||
import builtins | ||
import os | ||
import sys | ||
import traceback | ||
|
||
import requests | ||
|
||
|
||
def download_file(file_path: str, url: str) -> None: | ||
"""Download the file from the given URL to the given file path. | ||
Prints rather than raises if the file cannot be downloaded. | ||
""" | ||
file_dir = os.path.dirname(file_path) | ||
os.makedirs(file_dir, exist_ok=True) | ||
try: | ||
response = requests.get(url, timeout=5) | ||
|
||
response.raise_for_status() | ||
|
||
with open(file_path, "wb") as file: | ||
file.write(response.content) | ||
except requests.exceptions.RequestException: | ||
print(f"Error downloading {url=}\nto {file_path=}.\n{traceback.format_exc()}") | ||
|
||
|
||
def maybe_auto_download_file(url: str, abs_path: str, label: str | None = None) -> None: | ||
"""Download file if not exist and user confirms or auto-download is enabled.""" | ||
if os.path.isfile(abs_path): | ||
return | ||
|
||
# whether to auto-download model prediction files without prompting | ||
auto_download_files = os.getenv("MBD_AUTO_DOWNLOAD_FILES", "true").lower() == "true" | ||
|
||
is_ipython = hasattr(builtins, "__IPYTHON__") | ||
# default to 'y' if auto-download enabled or not in interactive session (TTY | ||
# or iPython) | ||
answer = ( | ||
"y" | ||
if auto_download_files or not (is_ipython or sys.stdin.isatty()) | ||
else input( | ||
f"{abs_path!r} associated with {label=} does not exist. Download it " | ||
"now? This will cache the file for future use. [y/n] " | ||
) | ||
) | ||
if answer.lower().strip() == "y": | ||
print(f"Downloading {label!r} from {url!r} to {abs_path!r}") | ||
download_file(abs_path, url) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import os | ||
from pathlib import Path | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
import requests | ||
|
||
from matbench_discovery.remote.fetch import maybe_auto_download_file | ||
|
||
|
||
def test_download_file(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: | ||
"""Test download_file function.""" | ||
|
||
from matbench_discovery.remote.fetch import download_file | ||
|
||
url = "https://example.com/test.txt" | ||
test_content = b"test content" | ||
dest_path = tmp_path / "test.txt" | ||
|
||
# Mock successful request | ||
mock_response = requests.Response() | ||
mock_response.status_code = 200 | ||
mock_response._content = test_content # noqa: SLF001 | ||
|
||
with patch("requests.get", return_value=mock_response): | ||
download_file(str(dest_path), url) | ||
assert dest_path.read_bytes() == test_content | ||
|
||
# Mock failed request | ||
mock_response = requests.Response() | ||
mock_response.status_code = 404 | ||
mock_response._content = b"Not found" # noqa: SLF001 | ||
|
||
with patch("requests.get", return_value=mock_response): | ||
download_file(str(dest_path), url) # Should print error but not raise | ||
|
||
stdout, stderr = capsys.readouterr() | ||
assert f"Error downloading {url=}" in stdout | ||
assert stderr == "" | ||
|
||
|
||
def test_maybe_auto_download_file( | ||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture | ||
) -> None: | ||
"""Test auto-download behavior of maybe_auto_download_file function.""" | ||
url = "https://example.com/file.txt" | ||
abs_path = f"{tmp_path}/test/file.txt" | ||
os.makedirs(os.path.dirname(abs_path), exist_ok=True) | ||
|
||
# Mock successful request | ||
mock_response = requests.Response() | ||
mock_response.status_code = 200 | ||
mock_response._content = b"test content" # noqa: SLF001 | ||
|
||
# Test 1: Auto-download enabled (default) | ||
monkeypatch.setenv("MBD_AUTO_DOWNLOAD_FILES", "true") | ||
with patch("requests.get", return_value=mock_response): | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
stdout, _ = capsys.readouterr() | ||
assert f"Downloading 'test' from {url!r}" in stdout | ||
assert os.path.isfile(abs_path) | ||
|
||
# Test 2: Auto-download disabled | ||
os.remove(abs_path) | ||
monkeypatch.setenv("MBD_AUTO_DOWNLOAD_FILES", "false") | ||
assert not os.path.isfile(abs_path) | ||
|
||
# Mock user input 'n' to skip download | ||
with ( | ||
patch("requests.get", return_value=mock_response), | ||
patch("builtins.input", return_value="n"), | ||
patch("sys.stdin.isatty", return_value=True), # force interactive mode | ||
): | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
assert not os.path.isfile(abs_path) | ||
|
||
# Test 3: Auto-download disabled but user confirms | ||
with ( | ||
patch("requests.get", return_value=mock_response), | ||
patch("builtins.input", return_value="y"), | ||
patch("sys.stdin.isatty", return_value=True), # force interactive mode | ||
): | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
stdout, _ = capsys.readouterr() | ||
assert f"Downloading 'test' from {url!r}" in stdout | ||
assert os.path.isfile(abs_path) | ||
|
||
# Test 4: File already exists (no download attempt) | ||
with patch("requests.get") as mock_get: | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
mock_get.assert_not_called() | ||
|
||
# Test 5: Non-interactive session (auto-download) | ||
os.remove(abs_path) | ||
with ( | ||
patch("requests.get", return_value=mock_response), | ||
patch("sys.stdin.isatty", return_value=False), | ||
): | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
stdout, _ = capsys.readouterr() | ||
assert f"Downloading 'test' from {url!r}" in stdout | ||
assert os.path.isfile(abs_path) | ||
|
||
# Test 6: IPython session with auto-download disabled | ||
os.remove(abs_path) | ||
with ( | ||
patch("requests.get", return_value=mock_response), | ||
patch("builtins.input", return_value="n"), | ||
patch("sys.stdin.isatty", return_value=True), # force interactive mode | ||
): | ||
maybe_auto_download_file(url, abs_path, label="test") | ||
assert not os.path.isfile(abs_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.