Skip to content

Commit

Permalink
Improve map, optimize and StreamingDataset (#18912)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
  • Loading branch information
3 people authored Nov 3, 2023
1 parent 809e952 commit f5f4d0a
Show file tree
Hide file tree
Showing 26 changed files with 402 additions and 346 deletions.
2 changes: 1 addition & 1 deletion docs/source-app/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def find_source():
linkcheck_anchors = False

# A timeout value, in seconds, for the linkcheck builder.
linkcheck_timeout = 10
linkcheck_timeout = 60

# ignore all links in any CHANGELOG file
linkcheck_exclude_documents = [r"^(.*\/)*CHANGELOG.*$"]
2 changes: 1 addition & 1 deletion docs/source-fabric/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def find_source():
linkcheck_anchors = False

# A timeout value, in seconds, for the linkcheck builder.
linkcheck_timeout = 10
linkcheck_timeout = 60

# ignore all links in any CHANGELOG file
linkcheck_exclude_documents = [r"^(.*\/)*CHANGELOG.*$"]
2 changes: 1 addition & 1 deletion docs/source-pytorch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def package_list_from_file(file):
linkcheck_anchors = False

# A timeout value, in seconds, for the linkcheck builder.
linkcheck_timeout = 10
linkcheck_timeout = 60

# ignore all links in any CHANGELOG file
linkcheck_exclude_documents = [r"^(.*\/)*CHANGELOG.*$"]
Expand Down
Empty file removed index_1.txt
Empty file.
4 changes: 2 additions & 2 deletions requirements/app/app.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
lightning-cloud ==0.5.46 # Must be pinned to ensure compatibility
lightning-cloud == 0.5.48 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.0.0, <4.8.0
typing-extensions >=4.4.0, <4.8.0
deepdiff >=5.7.0, <6.6.0
starsessions >=1.2.1, <2.0 # strict
fsspec >=2022.5.0, <2023.10.0
Expand Down
1 change: 1 addition & 0 deletions requirements/app/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ pympler
psutil <5.10.0
setuptools <68.3.0
requests-mock ==1.11.0
pandas
2 changes: 0 additions & 2 deletions requirements/app/ui.txt
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
streamlit >=1.13.0, <1.27.0
panel >=1.0.0, <1.3.0
61 changes: 30 additions & 31 deletions src/lightning/data/streaming/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,38 @@

import logging
import os
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming.constants import (
_INDEX_FILENAME,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.reader import BinaryReader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.writer import BinaryWriter

if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir

logger = logging.Logger(__name__)

if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
from lightning_cloud.resolver import _resolve_dir


@dataclass
class Dir:
"""Holds a directory path and possibly its associated remote URL."""

path: str
url: Optional[str] = None


class Cache:
def __init__(
self,
cache_dir: Optional[str] = None,
remote_dir: Optional[str] = None,
name: Optional[str] = None,
version: Optional[Union[int, Literal["latest"]]] = "latest",
input_dir: Optional[Union[str, Dir]],
compression: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
Expand All @@ -48,9 +54,7 @@ def __init__(
together in order to accelerate fetching.
Arguments:
cache_dir: The path to where the chunks will be stored.
remote_dir: The path to a remote folder where the data are located.
The scheme needs to be added to the path.
input_dir: The path to where the chunks will be or are stored.
name: The name of dataset in the cloud.
version: The version of the dataset in the cloud to use. By default, we will use the latest.
compression: The name of the algorithm to reduce the size of the chunks.
Expand All @@ -63,25 +67,20 @@ def __init__(
if not _TORCH_GREATER_EQUAL_2_1_0:
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")

self._cache_dir = cache_dir = str(cache_dir) if cache_dir else _try_create_cache_dir(name)
if not remote_dir:
remote_dir, has_index_file = _find_remote_dir(name, version)

# When the index exists, we don't care about the chunk_size anymore.
if has_index_file and (chunk_size is None and chunk_bytes is None):
chunk_size = 2

# Add the version to the cache_dir to avoid collisions.
if remote_dir and os.path.basename(remote_dir).startswith("version_"):
cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))

if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

self._cache_dir = cache_dir

self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression, item_loader=item_loader)
if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
raise ModuleNotFoundError("Lightning Cloud 0.5.48 or higher is required to use the cache.")

input_dir = _resolve_dir(input_dir)
self._cache_dir = input_dir.path
self._writer = BinaryWriter(
self._cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
)
self._reader = BinaryReader(
self._cache_dir,
remote_input_dir=input_dir.url,
compression=compression,
item_loader=item_loader,
)
self._is_done = False
self._distributed_env = _DistributedEnv.detect()

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46")
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48 = RequirementCache("lightning-cloud>=0.5.48")
_BOTO3_AVAILABLE = RequirementCache("boto3")

# DON'T CHANGE ORDER
Expand Down
Loading

0 comments on commit f5f4d0a

Please sign in to comment.