From 118642d64dde24fa3338da709d9f28659629ce7b Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Sun, 26 Feb 2023 16:13:52 -0700 Subject: [PATCH 01/10] support cloud storage in load_dataset via fsspec --- setup.py | 1 + src/datasets/utils/file_utils.py | 30 ++++++++++++++++++++++++++++-- tests/test_file_utils.py | 20 +++++++++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4c358d279cf..6f4f1643313 100644 --- a/setup.py +++ b/setup.py @@ -230,6 +230,7 @@ "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], "torch": ["torch"], "jax": ["jax>=0.2.8,!=0.3.2,<=0.3.25", "jaxlib>=0.1.65,<=0.3.25"], + "gcsfs": ["gcsfs"], "s3": ["s3fs"], "streaming": [], # for backward compatibility "dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE, diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 48242e8a4a1..26de041fd8c 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -22,6 +22,7 @@ from typing import List, Optional, Type, TypeVar, Union from urllib.parse import urljoin, urlparse +import fsspec import huggingface_hub import requests from huggingface_hub import HfFolder @@ -327,6 +328,23 @@ def _request_with_retry( return response +def fsspec_head(url, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") + try: + fsspec.filesystem(urlparse(url).scheme).info(url, timeout=timeout) + except Exception: + return False + return True + + +def fsspec_get(url, temp_file, timeout=10.0): + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") + try: + fsspec.filesystem(urlparse(url).scheme).get(url, temp_file, timeout=timeout) + except fsspec.FSTimeoutError as e: + raise ConnectionError(e) from None + + def ftp_head(url, timeout=10.0): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") try: @@ -400,6 +418,8 @@ def http_head( def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]: + if urlparse(url).scheme not in ("http", "https"): + return None headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token) response = http_head(url, headers=headers, max_retries=3) response.raise_for_status() @@ -453,6 +473,7 @@ def get_from_cache( cookies = None etag = None head_error = None + scheme = None # Try a first time to file the file on the local file system without eTag (None) # if we don't ask for 'force_download' then we spare a request @@ -469,8 +490,11 @@ def get_from_cache( # We don't have the file locally or we need an eTag if not local_files_only: - if url.startswith("ftp://"): + scheme = urlparse(url).scheme + if scheme == "ftp": connected = ftp_head(url) + elif scheme in ("s3", "gs"): + connected = fsspec_head(url) try: response = http_head( url, @@ -569,8 +593,10 @@ def _resumable_file_manager(): logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}") # GET file object - if url.startswith("ftp://"): + if scheme == "ftp": ftp_get(url, temp_file) + elif scheme in ("gs", "s3"): + fsspec_get(url, temp_file) else: http_get( url, diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 04b37e5eb40..09f3eeb4f7d 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -6,7 +6,16 @@ import zstandard as zstd from datasets.download.download_config import DownloadConfig -from datasets.utils.file_utils import OfflineModeIsEnabled, cached_path, ftp_get, ftp_head, http_get, http_head +from datasets.utils.file_utils import ( + OfflineModeIsEnabled, + cached_path, + fsspec_get, + fsspec_head, + ftp_get, + ftp_head, + http_get, + http_head, +) FILE_CONTENT = """\ @@ -102,3 +111,12 @@ def test_ftp_offline(tmp_path_factory): ftp_get("ftp://huggingface.co", temp_file=filename) with pytest.raises(OfflineModeIsEnabled): ftp_head("ftp://huggingface.co") + + +@patch("datasets.config.HF_DATASETS_OFFLINE", True) +def test_fsspec_offline(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.html" + with pytest.raises(OfflineModeIsEnabled): + fsspec_get("s3://huggingface.co", temp_file=filename) + with pytest.raises(OfflineModeIsEnabled): + fsspec_head("s3://huggingface.co") From 2869af6efcc10aa8d081764d79f5b45eff968e9d Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Tue, 28 Feb 2023 08:48:21 -0700 Subject: [PATCH 02/10] fsspec get uses tqdm, tries to handle additional protocols, and computes pseudo etag from head response --- src/datasets/utils/file_utils.py | 36 +++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 26de041fd8c..406dbef3131 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -330,19 +330,24 @@ def _request_with_retry( def fsspec_head(url, timeout=10.0): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - try: - fsspec.filesystem(urlparse(url).scheme).info(url, timeout=timeout) - except Exception: - return False - return True + fs, _, paths = fsspec.get_fs_token_paths(url) + if len(paths) > 1: + raise ValueError("HEAD can be called with at most one path but was called with {paths}") + return fs.info(paths[0], timeout=timeout) -def fsspec_get(url, temp_file, timeout=10.0): +def fsspec_get(url, temp_file, timeout=10.0, desc=None): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - try: - fsspec.filesystem(urlparse(url).scheme).get(url, temp_file, timeout=timeout) - except fsspec.FSTimeoutError as e: - raise ConnectionError(e) from None + fs, _, paths = fsspec.get_fs_token_paths(url) + if len(paths) > 1: + raise ValueError("GET can be called with at most one path but was called with {paths}") + callback = fsspec.callbacks.TqdmCallback( + tqdm_kwargs={ + "desc": desc or "Downloading", + "disable": logging.is_progress_bar_enabled(), + } + ) + fs.get(paths[0], temp_file, timeout=timeout, callback=callback) def ftp_head(url, timeout=10.0): @@ -493,8 +498,11 @@ def get_from_cache( scheme = urlparse(url).scheme if scheme == "ftp": connected = ftp_head(url) - elif scheme in ("s3", "gs"): - connected = fsspec_head(url) + elif scheme not in ("http", "https"): + response = fsspec_head(url) + # use the hash of the response as a pseudo ETag to detect changes + etag = json.dumps(response, sort_keys=True) if use_etag else None + connected = True try: response = http_head( url, @@ -595,8 +603,8 @@ def _resumable_file_manager(): # GET file object if scheme == "ftp": ftp_get(url, temp_file) - elif scheme in ("gs", "s3"): - fsspec_get(url, temp_file) + elif scheme not in ("http", "https"): + fsspec_get(url, temp_file, desc=download_desc) else: http_get( url, From a9e058e12bf908741a4e1e2cfb94fd4387251f6d Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Tue, 28 Feb 2023 08:50:52 -0700 Subject: [PATCH 03/10] Update setup.py --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6f4f1643313..7bb04e7c811 100644 --- a/setup.py +++ b/setup.py @@ -230,7 +230,8 @@ "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], "torch": ["torch"], "jax": ["jax>=0.2.8,!=0.3.2,<=0.3.25", "jaxlib>=0.1.65,<=0.3.25"], - "gcsfs": ["gcsfs"], + "gcs": ["gcsfs"], + "gs": ["gcsfs"], "s3": ["s3fs"], "streaming": [], # for backward compatibility "dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE, From c5d74b1f98d27dfff8c496444b5803d4a2dbbadf Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Wed, 1 Mar 2023 20:44:08 -0700 Subject: [PATCH 04/10] add test --- src/datasets/utils/file_utils.py | 12 ++++++------ tests/fixtures/fsspec.py | 4 ++++ tests/test_file_utils.py | 21 ++++++++++++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 406dbef3131..d78df5d1a8b 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -330,24 +330,24 @@ def _request_with_retry( def fsspec_head(url, timeout=10.0): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - fs, _, paths = fsspec.get_fs_token_paths(url) + fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout}) if len(paths) > 1: - raise ValueError("HEAD can be called with at most one path but was called with {paths}") - return fs.info(paths[0], timeout=timeout) + raise ValueError(f"HEAD can be called with at most one path but was called with {paths}") + return fs.info(paths[0]) def fsspec_get(url, temp_file, timeout=10.0, desc=None): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") - fs, _, paths = fsspec.get_fs_token_paths(url) + fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout}) if len(paths) > 1: - raise ValueError("GET can be called with at most one path but was called with {paths}") + raise ValueError(f"GET can be called with at most one path but was called with {paths}") callback = fsspec.callbacks.TqdmCallback( tqdm_kwargs={ "desc": desc or "Downloading", "disable": logging.is_progress_bar_enabled(), } ) - fs.get(paths[0], temp_file, timeout=timeout, callback=callback) + fs.get_file(paths[0], temp_file.name, callback=callback) def ftp_head(url, timeout=10.0): diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py index be49dd0bdeb..e7b653c7f5e 100644 --- a/tests/fixtures/fsspec.py +++ b/tests/fixtures/fsspec.py @@ -40,6 +40,10 @@ def info(self, path, *args, **kwargs): out["name"] = out["name"][len(self.local_root_dir) :] return out + def get_file(self, rpath, lpath, *args, **kwargs): + rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath)) + return self._fs.get_file(rpath, lpath, *args, **kwargs) + def cp_file(self, path1, path2, *args, **kwargs): path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 09f3eeb4f7d..be0992460b8 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -13,6 +13,7 @@ fsspec_head, ftp_get, ftp_head, + get_from_cache, http_get, http_head, ) @@ -22,16 +23,25 @@ Text data. Second line of data.""" +FILE_PATH = "file" + @pytest.fixture(scope="session") def zstd_path(tmp_path_factory): - path = tmp_path_factory.mktemp("data") / "file.zstd" + path = tmp_path_factory.mktemp("data") / FILE_PATH data = bytes(FILE_CONTENT, "utf-8") with zstd.open(path, "wb") as f: f.write(data) return path +@pytest.fixture +def mockfs_file(mockfs): + with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f: + f.write(FILE_CONTENT) + return mockfs + + @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file): input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path} @@ -89,6 +99,15 @@ def test_cached_path_missing_local(tmp_path): cached_path(missing_file) +def test_get_from_cache_fsspec(mockfs_file): + with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths: + mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH]) + output_path = get_from_cache("mock://huggingface.co") + with open(output_path) as f: + output_file_content = f.read() + assert output_file_content == FILE_CONTENT + + @patch("datasets.config.HF_DATASETS_OFFLINE", True) def test_cached_path_offline(): with pytest.raises(OfflineModeIsEnabled): From 9eae6d43f6c168db4f8880ea8ba41413834fd26a Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Wed, 1 Mar 2023 21:11:28 -0700 Subject: [PATCH 05/10] Update setup.py Co-authored-by: Alvaro Bartolome --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index 7bb04e7c811..4c358d279cf 100644 --- a/setup.py +++ b/setup.py @@ -230,8 +230,6 @@ "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], "torch": ["torch"], "jax": ["jax>=0.2.8,!=0.3.2,<=0.3.25", "jaxlib>=0.1.65,<=0.3.25"], - "gcs": ["gcsfs"], - "gs": ["gcsfs"], "s3": ["s3fs"], "streaming": [], # for backward compatibility "dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE, From c37215a4371eb9d70175ae37e7154e08932da1fc Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Thu, 2 Mar 2023 07:28:47 -0700 Subject: [PATCH 06/10] Update tests/test_file_utils.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- tests/test_file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index be0992460b8..64b0583e13e 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -28,7 +28,7 @@ @pytest.fixture(scope="session") def zstd_path(tmp_path_factory): - path = tmp_path_factory.mktemp("data") / FILE_PATH + path = tmp_path_factory.mktemp("data") / (FILE_PATH + ".zstd") data = bytes(FILE_CONTENT, "utf-8") with zstd.open(path, "wb") as f: f.write(data) From 341bda52058a4adefbc26f5785e53194495ecd20 Mon Sep 17 00:00:00 2001 From: Dean Wyatte Date: Sun, 5 Mar 2023 10:13:12 -0700 Subject: [PATCH 07/10] add tmpfs and use to test fsspec in get_from_cache --- tests/fixtures/fsspec.py | 29 +++++++++++++++++++++++++---- tests/test_file_utils.py | 18 ++++++++---------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py index e7b653c7f5e..8aaa181a77e 100644 --- a/tests/fixtures/fsspec.py +++ b/tests/fixtures/fsspec.py @@ -1,5 +1,6 @@ import posixpath from pathlib import Path +from unittest.mock import patch import fsspec import pytest @@ -40,10 +41,6 @@ def info(self, path, *args, **kwargs): out["name"] = out["name"][len(self.local_root_dir) :] return out - def get_file(self, rpath, lpath, *args, **kwargs): - rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath)) - return self._fs.get_file(rpath, lpath, *args, **kwargs) - def cp_file(self, path1, path2, *args, **kwargs): path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) @@ -77,10 +74,27 @@ def _strip_protocol(cls, path): return path +class TmpDirFileSystem(MockFileSystem): + protocol = "tmp" + tmp_dir = None + + def __init__(self, *args, **kwargs): + assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set" + super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("tmp://"): + path = path[6:] + return path + + @pytest.fixture def mock_fsspec(): original_registry = fsspec.registry.copy() fsspec.register_implementation("mock", MockFileSystem) + fsspec.register_implementation("tmp", TmpDirFileSystem) yield fsspec.registry = original_registry @@ -89,3 +103,10 @@ def mock_fsspec(): def mockfs(tmp_path_factory, mock_fsspec): local_fs_dir = tmp_path_factory.mktemp("mockfs") return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True) + + +@pytest.fixture +def tmpfs(tmp_path_factory, mock_fsspec): + tmp_fs_dir = tmp_path_factory.mktemp("tmpfs") + with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir): + yield TmpDirFileSystem() diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 64b0583e13e..a6175c3dd17 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -36,10 +36,10 @@ def zstd_path(tmp_path_factory): @pytest.fixture -def mockfs_file(mockfs): - with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f: +def tmpfs_file(tmpfs): + with open(os.path.join(tmpfs.local_root_dir, FILE_PATH), "w") as f: f.write(FILE_CONTENT) - return mockfs + return FILE_PATH @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) @@ -99,13 +99,11 @@ def test_cached_path_missing_local(tmp_path): cached_path(missing_file) -def test_get_from_cache_fsspec(mockfs_file): - with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths: - mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH]) - output_path = get_from_cache("mock://huggingface.co") - with open(output_path) as f: - output_file_content = f.read() - assert output_file_content == FILE_CONTENT +def test_get_from_cache_fsspec(tmpfs_file): + output_path = get_from_cache(f"tmp://{tmpfs_file}") + with open(output_path) as f: + output_file_content = f.read() + assert output_file_content == FILE_CONTENT @patch("datasets.config.HF_DATASETS_OFFLINE", True) From 74c0c45668808cd13dc072e9676991b2ae643a14 Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Fri, 10 Mar 2023 14:17:50 -0700 Subject: [PATCH 08/10] Update src/datasets/utils/file_utils.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/utils/file_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index d78df5d1a8b..4b9ed454868 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -344,7 +344,7 @@ def fsspec_get(url, temp_file, timeout=10.0, desc=None): callback = fsspec.callbacks.TqdmCallback( tqdm_kwargs={ "desc": desc or "Downloading", - "disable": logging.is_progress_bar_enabled(), + "disable": not logging.is_progress_bar_enabled(), } ) fs.get_file(paths[0], temp_file.name, callback=callback) From b2c958aa845d1f25b948ebf4f3e4391bf5d75cd0 Mon Sep 17 00:00:00 2001 From: Dean Wyatte <2512762+dwyatte@users.noreply.github.com> Date: Fri, 10 Mar 2023 14:18:01 -0700 Subject: [PATCH 09/10] Update src/datasets/utils/file_utils.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/utils/file_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 4b9ed454868..00d6c00b218 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -501,7 +501,8 @@ def get_from_cache( elif scheme not in ("http", "https"): response = fsspec_head(url) # use the hash of the response as a pseudo ETag to detect changes - etag = json.dumps(response, sort_keys=True) if use_etag else None + # s3fs uses "ETag", gcsfs uses "etag" + etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None connected = True try: response = http_head( From b9b1075c3f73528b79e81856b89a5659ee119448 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Sat, 11 Mar 2023 01:54:06 +0100 Subject: [PATCH 10/10] remove comment --- src/datasets/utils/file_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 00d6c00b218..c2f19b5ce7a 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -500,7 +500,6 @@ def get_from_cache( connected = ftp_head(url) elif scheme not in ("http", "https"): response = fsspec_head(url) - # use the hash of the response as a pseudo ETag to detect changes # s3fs uses "ETag", gcsfs uses "etag" etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None connected = True