
Simplify datachain.lib.listing by reusing Client.scandir()
Fixes #373
rlamy committed Aug 30, 2024
1 parent ed71afa commit 315816e
Showing 2 changed files with 19 additions and 93 deletions.
99 changes: 6 additions & 93 deletions src/datachain/lib/listing.py
@@ -1,103 +1,14 @@
-import asyncio
-from collections.abc import AsyncIterator, Iterator, Sequence
-from typing import Callable, Optional
+from collections.abc import Iterator
+from typing import Callable
 
-from botocore.exceptions import ClientError
 from fsspec.asyn import get_loop
 
 from datachain.asyn import iter_over_async
 from datachain.client import Client
-from datachain.error import ClientError as DataChainClientError
 from datachain.lib.file import File
-
-ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
-
-DELIMITER = "/"  # Path delimiter
-FETCH_WORKERS = 100
-
-
-async def _fetch_dir(client, prefix, result_queue) -> set[str]:
-    path = f"{client.name}/{prefix}"
-    infos = await client.ls_dir(path)
-    files = []
-    subdirs = set()
-    for info in infos:
-        full_path = info["name"]
-        subprefix = client.rel_path(full_path)
-        if prefix.strip(DELIMITER) == subprefix.strip(DELIMITER):
-            continue
-        if info["type"] == "directory":
-            subdirs.add(subprefix)
-        else:
-            files.append(client.info_to_file(info, subprefix))
-    if files:
-        await result_queue.put(files)
-    return subdirs
-
-
-async def _fetch(
-    client, start_prefix: str, result_queue: ResultQueue, fetch_workers
-) -> None:
-    loop = get_loop()
-
-    queue: asyncio.Queue[str] = asyncio.Queue()
-    queue.put_nowait(start_prefix)
-
-    async def process(queue) -> None:
-        while True:
-            prefix = await queue.get()
-            try:
-                subdirs = await _fetch_dir(client, prefix, result_queue)
-                for subdir in subdirs:
-                    queue.put_nowait(subdir)
-            except Exception:
-                while not queue.empty():
-                    queue.get_nowait()
-                    queue.task_done()
-                raise
-
-            finally:
-                queue.task_done()
-
-    try:
-        workers: list[asyncio.Task] = [
-            loop.create_task(process(queue)) for _ in range(fetch_workers)
-        ]
-
-        # Wait for all fetch tasks to complete
-        await queue.join()
-        # Stop the workers
-        excs = []
-        for worker in workers:
-            if worker.done() and (exc := worker.exception()):
-                excs.append(exc)
-            else:
-                worker.cancel()
-        if excs:
-            raise excs[0]
-    except ClientError as exc:
-        raise DataChainClientError(
-            exc.response.get("Error", {}).get("Message") or exc,
-            exc.response.get("Error", {}).get("Code"),
-        ) from exc
-    finally:
-        # This ensures the progress bar is closed before any exceptions are raised
-        result_queue.put_nowait(None)
-
-
-async def _scandir(client, prefix, fetch_workers) -> AsyncIterator:
-    """Recursively goes through dir tree and yields files"""
-    result_queue: ResultQueue = asyncio.Queue()
-    loop = get_loop()
-    main_task = loop.create_task(_fetch(client, prefix, result_queue, fetch_workers))
-    while (files := await result_queue.get()) is not None:
-        for f in files:
-            yield f
-
-    await main_task
 
 
-def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Callable:
+def list_bucket(uri: str, client_config=None) -> Callable:
     """
     Function that returns another generator function that yields File objects
     from bucket where each File represents one bucket entry.
@@ -106,6 +17,8 @@ def list_bucket(uri: str, client_config=None, fetch_workers=FETCH_WORKERS) -> Ca
     def list_func() -> Iterator[File]:
         config = client_config or {}
         client, path = Client.parse_url(uri, None, **config)  # type: ignore[arg-type]
-        yield from iter_over_async(_scandir(client, path, fetch_workers), get_loop())
+        for entries in iter_over_async(client.scandir(path), get_loop()):
+            for entry in entries:
+                yield entry.to_file(client.uri)
 
     return list_func
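
For context, a minimal usage sketch of the new code path (the bucket URI and prefix below are hypothetical, and listing them assumes matching credentials). iter_over_async() drives the client's async scandir() generator on fsspec's event loop and exposes it as a plain synchronous iterator:

    from datachain.lib.listing import list_bucket

    # Hypothetical bucket/prefix; any URI the Client supports should work the same way.
    list_func = list_bucket("s3://example-bucket/images/")
    for file in list_func():
        # Each batch yielded by Client.scandir() is flattened into
        # datachain.lib.file.File objects via the entry's new to_file() helper.
        print(file.path, file.size)
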
13 changes: 13 additions & 0 deletions src/datachain/node.py
@@ -4,6 +4,7 @@
 import attrs
 
 from datachain.cache import UniqueId
+from datachain.lib.file import File
 from datachain.storage import StorageURI
 from datachain.utils import TIME_ZERO, time_to_str
 
@@ -189,6 +190,18 @@ def parent(self):
             return ""
         return split[0]
 
+    def to_file(self, source: str) -> File:
+        return File(
+            source=source,
+            path=self.path,
+            size=self.size,
+            version=self.version,
+            etag=self.etag,
+            is_latest=self.is_latest,
+            last_modified=self.last_modified,
+            location=self.location,
+        )
+
 
 def get_path(parent: str, name: str):
     return f"{parent}/{name}" if parent else name
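
Conceptually, the new helper is a field-for-field copy onto a File rooted at the given source URI. A standalone sketch of the same mapping (the stand-in dataclass below is illustrative only, not the real node class, and the field values are made up):

    from dataclasses import dataclass
    from datetime import datetime, timezone
    from typing import Optional

    from datachain.lib.file import File

    @dataclass
    class EntryStandIn:  # illustrative stand-in for the class extended in node.py
        path: str
        size: int
        version: str = ""
        etag: str = ""
        is_latest: bool = True
        last_modified: datetime = datetime(1970, 1, 1, tzinfo=timezone.utc)
        location: Optional[dict] = None

        def to_file(self, source: str) -> File:
            # Same field-for-field copy as the method added in this commit
            return File(
                source=source,
                path=self.path,
                size=self.size,
                version=self.version,
                etag=self.etag,
                is_latest=self.is_latest,
                last_modified=self.last_modified,
                location=self.location,
            )

    entry = EntryStandIn(path="images/cat.jpg", size=2048, etag="abc123")
    print(entry.to_file("s3://example-bucket"))
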
