Skip to content

Commit aae3f49

Browse files
WIP
1 parent cf67b76 commit aae3f49

File tree

1 file changed

+82
-53
lines changed

1 file changed

+82
-53
lines changed

pyathena/filesystem/s3.py

Lines changed: 82 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
174174
bucket=bucket,
175175
key=None,
176176
version_id=None,
177+
delimiter=None,
177178
)
178179
self.dircache[bucket] = file
179180
else:
@@ -207,6 +208,7 @@ def _head_object(
207208
bucket=bucket,
208209
key=key,
209210
version_id=version_id,
211+
delimiter=None,
210212
)
211213
self.dircache[path] = file
212214
else:
@@ -231,6 +233,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
231233
bucket=b["Name"],
232234
key=None,
233235
version_id=None,
236+
delimiter=None,
234237
)
235238
for b in response["Buckets"]
236239
]
@@ -251,58 +254,63 @@ def _ls_dirs(
251254
bucket, key, version_id = self.parse_path(path)
252255
if key:
253256
prefix = f"{key}/{prefix if prefix else ''}"
254-
if path not in self.dircache or refresh:
255-
files: List[S3Object] = []
256-
while True:
257-
request: Dict[Any, Any] = {
258-
"Bucket": bucket,
259-
"Prefix": prefix,
260-
"Delimiter": delimiter,
261-
}
262-
if next_token:
263-
request.update({"ContinuationToken": next_token})
264-
if max_keys:
265-
request.update({"MaxKeys": max_keys})
266-
response = self._call(
267-
self._client.list_objects_v2,
268-
**request,
269-
)
270-
files.extend(
271-
S3Object(
272-
init={
273-
"ContentLength": 0,
274-
"ContentType": None,
275-
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
276-
"ETag": None,
277-
"LastModified": None,
278-
},
279-
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
280-
bucket=bucket,
281-
key=c["Prefix"][:-1].rstrip("/"),
282-
version_id=version_id,
283-
)
284-
for c in response.get("CommonPrefixes", [])
285-
)
286-
files.extend(
287-
S3Object(
288-
init=c,
289-
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
290-
bucket=bucket,
291-
key=c["Key"],
292-
)
293-
for c in response.get("Contents", [])
294-
)
295-
next_token = response.get("NextContinuationToken")
296-
if not next_token:
297-
break
298-
if files:
299-
self.dircache[path] = files
300-
else:
257+
258+
if path in self.dircache and not refresh:
301259
cache = self.dircache[path]
302260
if not isinstance(cache, list):
303-
files = [cache]
261+
caches = [cache]
304262
else:
305-
files = cache
263+
caches = cache
264+
if all([f.delimiter == delimiter for f in caches]):
265+
return caches
266+
267+
files: List[S3Object] = []
268+
while True:
269+
request: Dict[Any, Any] = {
270+
"Bucket": bucket,
271+
"Prefix": prefix,
272+
"Delimiter": delimiter,
273+
}
274+
if next_token:
275+
request.update({"ContinuationToken": next_token})
276+
if max_keys:
277+
request.update({"MaxKeys": max_keys})
278+
response = self._call(
279+
self._client.list_objects_v2,
280+
**request,
281+
)
282+
files.extend(
283+
S3Object(
284+
init={
285+
"ContentLength": 0,
286+
"ContentType": None,
287+
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
288+
"ETag": None,
289+
"LastModified": None,
290+
},
291+
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
292+
bucket=bucket,
293+
key=c["Prefix"][:-1].rstrip("/"),
294+
version_id=version_id,
295+
delimiter=delimiter,
296+
)
297+
for c in response.get("CommonPrefixes", [])
298+
)
299+
files.extend(
300+
S3Object(
301+
init=c,
302+
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
303+
bucket=bucket,
304+
key=c["Key"],
305+
delimiter=delimiter,
306+
)
307+
for c in response.get("Contents", [])
308+
)
309+
next_token = response.get("NextContinuationToken")
310+
if not next_token:
311+
break
312+
if files:
313+
self.dircache[path] = files
306314
return files
307315

308316
def ls(
@@ -337,6 +345,7 @@ def info(self, path: str, **kwargs) -> S3Object:
337345
bucket=bucket,
338346
key=None,
339347
version_id=None,
348+
delimiter=None,
340349
)
341350
if not refresh:
342351
caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
@@ -363,6 +372,7 @@ def info(self, path: str, **kwargs) -> S3Object:
363372
bucket=bucket,
364373
key=key.rstrip("/") if key else None,
365374
version_id=version_id,
375+
delimiter=None,
366376
)
367377
if key:
368378
object_info = self._head_object(path, refresh=refresh, version_id=version_id)
@@ -399,31 +409,50 @@ def info(self, path: str, **kwargs) -> S3Object:
399409
bucket=bucket,
400410
key=key.rstrip("/") if key else None,
401411
version_id=version_id,
412+
delimiter=None,
402413
)
403414
else:
404415
raise FileNotFoundError(path)
405416

406-
def find(
417+
def _find(
407418
self,
408419
path: str,
409420
maxdepth: Optional[int] = None,
410421
withdirs: Optional[bool] = None,
411-
detail: bool = False,
412422
**kwargs,
413-
) -> Union[Dict[str, S3Object], List[str]]:
414-
# TODO: Support maxdepth and withdirs
423+
) -> List[S3Object]:
415424
path = self._strip_protocol(path)
416425
if path in ["", "/"]:
417426
raise ValueError("Cannot traverse all files in S3.")
418427
bucket, key, _ = self.parse_path(path)
419428
prefix = kwargs.pop("prefix", "")
429+
if maxdepth:
430+
return super().find(
431+
path=path,
432+
maxdepth=maxdepth,
433+
withdirs=withdirs,
434+
detail=True,
435+
**kwargs
436+
).values()
420437

421438
files = self._ls_dirs(path, prefix=prefix, delimiter="")
422439
if not files and key:
423440
try:
424441
files = [self.info(path)]
425442
except FileNotFoundError:
426443
files = []
444+
return files
445+
446+
def find(
447+
self,
448+
path: str,
449+
maxdepth: Optional[int] = None,
450+
withdirs: Optional[bool] = None,
451+
detail: bool = False,
452+
**kwargs,
453+
) -> Union[Dict[str, S3Object], List[str]]:
454+
# TODO: Support withdirs
455+
files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
427456
if detail:
428457
return {f.name: f for f in files}
429458
else:

0 commit comments

Comments
 (0)