Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cloud storage in load_dataset via fsspec #5580

Merged
merged 10 commits into from
Mar 11, 2023
38 changes: 36 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse

import fsspec
import huggingface_hub
import requests
from huggingface_hub import HfFolder
Expand Down Expand Up @@ -327,6 +328,28 @@ def _request_with_retry(
return response


def fsspec_head(url, timeout=10.0):
    """Return the fsspec ``info`` entry for the single file designated by ``url``.

    Args:
        url: URL handed to ``fsspec.get_fs_token_paths`` to resolve the
            filesystem and the path(s).
        timeout: Value forwarded to the filesystem as the ``requests_timeout``
            storage option.

    Returns:
        The result of ``fs.info`` for the resolved path.

    Raises:
        FileNotFoundError: If ``url`` resolves to no path at all.
        ValueError: If ``url`` resolves to more than one path.
    """
    # Offline guard runs before any filesystem access; it is defined elsewhere
    # in this module and is expected to raise when offline mode is enabled.
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
    if len(paths) > 1:
        raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
    if not paths:
        # Without this guard an empty resolution surfaces as a bare IndexError
        # on paths[0], which hides the real cause from the caller.
        raise FileNotFoundError(f"No file found at {url}")
    return fs.info(paths[0])


def fsspec_get(url, temp_file, timeout=10.0, desc=None):
    """Download the single file designated by ``url`` into ``temp_file``.

    Args:
        url: URL handed to ``fsspec.get_fs_token_paths`` to resolve the
            filesystem and the path(s).
        temp_file: Object whose ``name`` attribute is the local destination path.
        timeout: Value forwarded to the filesystem as the ``requests_timeout``
            storage option.
        desc: Progress-bar description; ``"Downloading"`` is used when falsy.

    Raises:
        FileNotFoundError: If ``url`` resolves to no path at all.
        ValueError: If ``url`` resolves to more than one path.
    """
    # Offline guard runs before any filesystem access; it is defined elsewhere
    # in this module and is expected to raise when offline mode is enabled.
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
    if len(paths) > 1:
        raise ValueError(f"GET can be called with at most one path but was called with {paths}")
    if not paths:
        # Without this guard an empty resolution surfaces as a bare IndexError
        # on paths[0], which hides the real cause from the caller.
        raise FileNotFoundError(f"No file found at {url}")
    callback = fsspec.callbacks.TqdmCallback(
        tqdm_kwargs={
            "desc": desc or "Downloading",
            # Honour the library-wide progress-bar switch.
            "disable": not logging.is_progress_bar_enabled(),
        }
    )
    fs.get_file(paths[0], temp_file.name, callback=callback)


def ftp_head(url, timeout=10.0):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
try:
Expand Down Expand Up @@ -400,6 +423,8 @@ def http_head(


def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
if urlparse(url).scheme not in ("http", "https"):
return None
headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
response = http_head(url, headers=headers, max_retries=3)
response.raise_for_status()
Expand Down Expand Up @@ -453,6 +478,7 @@ def get_from_cache(
cookies = None
etag = None
head_error = None
scheme = None

# Try a first time to find the file on the local file system without eTag (None)
# if we don't ask for 'force_download' then we spare a request
Expand All @@ -469,8 +495,14 @@ def get_from_cache(

# We don't have the file locally or we need an eTag
if not local_files_only:
if url.startswith("ftp://"):
scheme = urlparse(url).scheme
if scheme == "ftp":
connected = ftp_head(url)
elif scheme not in ("http", "https"):
response = fsspec_head(url)
# s3fs uses "ETag", gcsfs uses "etag"
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
connected = True
try:
response = http_head(
url,
Expand Down Expand Up @@ -569,8 +601,10 @@ def _resumable_file_manager():
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

# GET file object
if url.startswith("ftp://"):
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme not in ("http", "https"):
fsspec_get(url, temp_file, desc=download_desc)
else:
http_get(
url,
Expand Down
25 changes: 25 additions & 0 deletions tests/fixtures/fsspec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import posixpath
from pathlib import Path
from unittest.mock import patch

import fsspec
import pytest
Expand Down Expand Up @@ -73,10 +74,27 @@ def _strip_protocol(cls, path):
return path


class TmpDirFileSystem(MockFileSystem):
    """Mock filesystem exposed under the ``tmp`` protocol.

    The backing directory comes from the class attribute ``tmp_dir``, which
    has to be set before an instance is created.
    """

    protocol = "tmp"
    # Directory passed to the parent as ``local_root_dir``; None until set.
    tmp_dir = None

    def __init__(self, *args, **kwargs):
        """Initialise the parent filesystem rooted at ``tmp_dir``."""
        assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
        super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True)

    @classmethod
    def _strip_protocol(cls, path):
        """Return ``path`` as a string with a leading ``tmp://`` removed, if present."""
        prefix = "tmp://"
        text = stringify_path(path)
        return text[len(prefix) :] if text.startswith(prefix) else text


@pytest.fixture
def mock_fsspec():
    """Register the ``mock`` and ``tmp`` protocols with fsspec for one test.

    Both protocols are removed from fsspec's registry again on teardown.
    """
    # ``fsspec.registry`` (package attribute) is only a read-only view of the
    # real registry dict, so rebinding that attribute does not unregister
    # anything. The underlying dict has to be edited instead. Imported locally
    # to keep the file's top-level imports untouched.
    from fsspec.registry import _registry as fsspec_registry

    fsspec.register_implementation("mock", MockFileSystem)
    fsspec.register_implementation("tmp", TmpDirFileSystem)
    yield
    for protocol in ("mock", "tmp"):
        fsspec_registry.pop(protocol, None)

Expand All @@ -85,3 +103,10 @@ def mock_fsspec():
def mockfs(tmp_path_factory, mock_fsspec):
    """Return a MockFileSystem rooted in a fresh "mockfs" temporary directory.

    ``mock_fsspec`` is not used in the body; it is requested only so that it
    runs first.
    """
    root_dir = tmp_path_factory.mktemp("mockfs")
    return MockFileSystem(local_root_dir=root_dir, auto_mkdir=True)


@pytest.fixture
def tmpfs(tmp_path_factory, mock_fsspec):
    """Yield a TmpDirFileSystem backed by a fresh "tmpfs" temporary directory.

    ``TmpDirFileSystem.tmp_dir`` is patched only while the fixture is active.
    """
    backing_dir = tmp_path_factory.mktemp("tmpfs")
    tmp_dir_patch = patch.object(TmpDirFileSystem, "tmp_dir", backing_dir)
    with tmp_dir_patch:
        yield TmpDirFileSystem()
39 changes: 37 additions & 2 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,42 @@
import zstandard as zstd

from datasets.download.download_config import DownloadConfig
from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head
from datasets.utils.file_utils import (
OfflineModeIsEnabled,
cached_path,
fsspec_get,
fsspec_head,
ftp_get,
ftp_head,
get_from_cache,
http_get,
http_head,
)


# Two-line text payload (no trailing newline) written by the file fixtures
# and compared against in the tests.
FILE_CONTENT = "Text data.\nSecond line of data."

# Base name of the file the fixtures create.
FILE_PATH = "file"


@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
    """Write FILE_CONTENT, zstd-compressed, to a temporary file and return its path."""
    # Single assignment: the earlier duplicate created an extra temp directory
    # whose path was immediately overwritten.
    path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd")
    payload = bytes(FILE_CONTENT, "utf-8")
    with zstd.open(path, "wb") as f:
        f.write(payload)
    return path


@pytest.fixture
def tmpfs_file(tmpfs):
    """Create FILE_PATH holding FILE_CONTENT under the tmpfs root and return FILE_PATH."""
    target = os.path.join(tmpfs.local_root_dir, FILE_PATH)
    with open(target, "w") as handle:
        handle.write(FILE_CONTENT)
    return FILE_PATH


@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
Expand Down Expand Up @@ -80,6 +99,13 @@ def test_cached_path_missing_local(tmp_path):
cached_path(missing_file)


def test_get_from_cache_fsspec(tmpfs_file):
    """get_from_cache on a ``tmp://`` URL yields a local file with the original content."""
    url = f"tmp://{tmpfs_file}"
    cached_file = get_from_cache(url)
    with open(cached_file) as handle:
        assert handle.read() == FILE_CONTENT


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_cached_path_offline():
with pytest.raises(OfflineModeIsEnabled):
Expand All @@ -102,3 +128,12 @@ def test_ftp_offline(tmp_path_factory):
ftp_get("ftp://huggingface.co", temp_file=filename)
with pytest.raises(OfflineModeIsEnabled):
ftp_head("ftp://huggingface.co")


@patch("datasets.config.HF_DATASETS_OFFLINE", True)
def test_fsspec_offline(tmp_path_factory):
    """Both fsspec helpers raise OfflineModeIsEnabled while offline mode is patched on."""
    url = "s3://huggingface.co"
    destination = tmp_path_factory.mktemp("data") / "file.html"
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_get(url, temp_file=destination)
    with pytest.raises(OfflineModeIsEnabled):
        fsspec_head(url)