From dd13a2977b763598f17b94852873da97698225b4 Mon Sep 17 00:00:00 2001
From: laughingman7743
Date: Sun, 12 May 2024 17:33:23 +0900
Subject: [PATCH] WIP

---
 pyathena/filesystem/s3.py | 231 +++++++++++++++++++++++---------------
 1 file changed, 139 insertions(+), 92 deletions(-)

diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py
index 8a062570..acbd1b70 100644
--- a/pyathena/filesystem/s3.py
+++ b/pyathena/filesystem/s3.py
@@ -125,11 +125,11 @@ def _get_client_compatible_with_s3fs(self, **kwargs) -> BaseClient:
         if anon:
             config_kwargs.update({"signature_version": UNSIGNED})
         else:
-            creds = {
-                "aws_access_key_id": kwargs.pop("key", kwargs.pop("username", None)),
-                "aws_secret_access_key": kwargs.pop("secret", kwargs.pop("password", None)),
-                "aws_session_token": kwargs.pop("token", None),
-            }
+            creds = dict(
+                aws_access_key_id=kwargs.pop("key", kwargs.pop("username", None)),
+                aws_secret_access_key=kwargs.pop("secret", kwargs.pop("password", None)),
+                aws_session_token=kwargs.pop("token", None),
+            )
             kwargs.update(**creds)
             client_kwargs.update(**creds)
 
@@ -148,7 +148,8 @@ def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]:
         match = S3FileSystem.PATTERN_PATH.search(path)
         if match:
             return match.group("bucket"), match.group("key"), match.group("version_id")
-        raise ValueError(f"Invalid S3 path format {path}.")
+        else:
+            raise ValueError(f"Invalid S3 path format {path}.")
 
     def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
         if bucket not in self.dircache or refresh:
@@ -173,6 +174,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
             self.dircache[bucket] = file
         else:
@@ -206,6 +208,7 @@ def _head_object(
                 bucket=bucket,
                 key=key,
                 version_id=version_id,
+                delimiter=None,
             )
             self.dircache[path] = file
         else:
@@ -230,6 +233,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
                     bucket=b["Name"],
                     key=None,
                     version_id=None,
+                    delimiter=None,
                 )
                 for b in response["Buckets"]
             ]
@@ -250,55 +254,63 @@ def _ls_dirs(
         bucket, key, version_id = self.parse_path(path)
         if key:
             prefix = f"{key}/{prefix if prefix else ''}"
-        if path not in self.dircache or refresh:
-            files: List[S3Object] = []
-            while True:
-                request: Dict[Any, Any] = {
-                    "Bucket": bucket,
-                    "Prefix": prefix,
-                    "Delimiter": delimiter,
-                }
-                if next_token:
-                    request.update({"ContinuationToken": next_token})
-                if max_keys:
-                    request.update({"MaxKeys": max_keys})
-                response = self._call(
-                    self._client.list_objects_v2,
-                    **request,
-                )
-                files.extend(
-                    S3Object(
-                        init={
-                            "ContentLength": 0,
-                            "ContentType": None,
-                            "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
-                            "ETag": None,
-                            "LastModified": None,
-                        },
-                        type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
-                        bucket=bucket,
-                        key=c["Prefix"][:-1].rstrip("/"),
-                        version_id=version_id,
-                    )
-                    for c in response.get("CommonPrefixes", [])
-                )
-                files.extend(
-                    S3Object(
-                        init=c,
-                        type=S3ObjectType.S3_OBJECT_TYPE_FILE,
-                        bucket=bucket,
-                        key=c["Key"],
-                    )
-                    for c in response.get("Contents", [])
-                )
-                next_token = response.get("NextContinuationToken")
-                if not next_token:
-                    break
-            if files:
-                self.dircache[path] = files
-        else:
-            cache = self.dircache[path]
-            files = cache if isinstance(cache, list) else [cache]
+
+        if path in self.dircache and not refresh:
+            cache = self.dircache[path]
+            if not isinstance(cache, list):
+                caches = [cache]
+            else:
+                caches = cache
+            if all([f.delimiter == delimiter for f in caches]):
+                return caches
+
+        files: List[S3Object] = []
+        while True:
+            request: Dict[Any, Any] = {
+                "Bucket": bucket,
+                "Prefix": prefix,
+                "Delimiter": delimiter,
+            }
+            if next_token:
+                request.update({"ContinuationToken": next_token})
+            if max_keys:
+                request.update({"MaxKeys": max_keys})
+            response = self._call(
+                self._client.list_objects_v2,
+                **request,
+            )
+            files.extend(
+                S3Object(
+                    init={
+                        "ContentLength": 0,
+                        "ContentType": None,
+                        "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
+                        "ETag": None,
+                        "LastModified": None,
+                    },
+                    type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
+                    bucket=bucket,
+                    key=c["Prefix"][:-1].rstrip("/"),
+                    version_id=version_id,
+                    delimiter=delimiter,
+                )
+                for c in response.get("CommonPrefixes", [])
+            )
+            files.extend(
+                S3Object(
+                    init=c,
+                    type=S3ObjectType.S3_OBJECT_TYPE_FILE,
+                    bucket=bucket,
+                    key=c["Key"],
+                    delimiter=delimiter,
+                )
+                for c in response.get("Contents", [])
+            )
+            next_token = response.get("NextContinuationToken")
+            if not next_token:
+                break
+        if files:
+            self.dircache[path] = files
         return files
 
     def ls(
@@ -313,7 +325,7 @@ def ls(
             file = self._head_object(path, refresh=refresh)
             if file:
                 files = [file]
-        return list(files) if detail else [f.name for f in files]
+        return [f for f in files] if detail else [f.name for f in files]
 
     def info(self, path: str, **kwargs) -> S3Object:
         refresh = kwargs.pop("refresh", False)
@@ -333,6 +345,7 @@ def info(self, path: str, **kwargs) -> S3Object:
                 bucket=bucket,
                 key=None,
                 version_id=None,
+                delimiter=None,
             )
         if not refresh:
             caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
@@ -346,19 +359,21 @@ def info(self, path: str, **kwargs) -> S3Object:
                 if cache:
                     return cache
-                return S3Object(
-                    init={
-                        "ContentLength": 0,
-                        "ContentType": None,
-                        "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
-                        "ETag": None,
-                        "LastModified": None,
-                    },
-                    type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
-                    bucket=bucket,
-                    key=key.rstrip("/") if key else None,
-                    version_id=version_id,
-                )
+                else:
+                    return S3Object(
+                        init={
+                            "ContentLength": 0,
+                            "ContentType": None,
+                            "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
+                            "ETag": None,
+                            "LastModified": None,
+                        },
+                        type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
+                        bucket=bucket,
+                        key=key.rstrip("/") if key else None,
+                        version_id=version_id,
+                        delimiter=None,
+                    )
         if key:
             object_info = self._head_object(path, refresh=refresh, version_id=version_id)
             if object_info:
@@ -367,7 +382,8 @@ def info(self, path: str, **kwargs) -> S3Object:
                 return object_info
             bucket_info = self._head_bucket(path, refresh=refresh)
             if bucket_info:
                 return bucket_info
-            raise FileNotFoundError(path)
+            else:
+                raise FileNotFoundError(path)
 
         response = self._call(
             self._client.list_objects_v2,
@@ -393,23 +409,31 @@ def info(self, path: str, **kwargs) -> S3Object:
                 bucket=bucket,
                 key=key.rstrip("/") if key else None,
                 version_id=version_id,
+                delimiter=None,
             )
-        raise FileNotFoundError(path)
+        else:
+            raise FileNotFoundError(path)
 
-    def find(
+    def _find(
         self,
         path: str,
         maxdepth: Optional[int] = None,
         withdirs: Optional[bool] = None,
-        detail: bool = False,
         **kwargs,
-    ) -> Union[Dict[str, S3Object], List[str]]:
-        # TODO: Support maxdepth and withdirs
+    ) -> List[S3Object]:
         path = self._strip_protocol(path)
         if path in ["", "/"]:
            raise ValueError("Cannot traverse all files in S3.")
         bucket, key, _ = self.parse_path(path)
         prefix = kwargs.pop("prefix", "")
+        if maxdepth:
+            return super().find(
+                path=path,
+                maxdepth=maxdepth,
+                withdirs=withdirs,
+                detail=True,
+                **kwargs
+            ).values()
         files = self._ls_dirs(path, prefix=prefix, delimiter="")
         if not files and key:
             try:
@@ -417,9 +441,22 @@ def find(
                 files = [self.info(path)]
             except FileNotFoundError:
                 files = []
+        return files
+
+    def find(
+        self,
+        path: str,
+        maxdepth: Optional[int] = None,
+        withdirs: Optional[bool] = None,
+        detail: bool = False,
+        **kwargs,
+    ) -> Union[Dict[str, S3Object], List[str]]:
+        # TODO: Support withdirs
+        files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
         if detail:
             return {f.name: f for f in files}
-        return [f.name for f in files]
+        else:
+            return [f.name for f in files]
 
     def exists(self, path: str, **kwargs) -> bool:
         path = self._strip_protocol(path)
@@ -432,7 +469,10 @@ def exists(self, path: str, **kwargs) -> bool:
                 if self._ls_from_cache(path):
                     return True
                 info = self.info(path)
-                return bool(info)
+                if info:
+                    return True
+                else:
+                    return False
             except FileNotFoundError:
                 return False
         elif self.dircache.get(bucket, False):
@@ -444,7 +484,10 @@ def exists(self, path: str, **kwargs) -> bool:
             except FileNotFoundError:
                 pass
             file = self._head_bucket(bucket)
-            return bool(file)
+            if file:
+                return True
+            else:
+                return False
 
     def rm_file(self, path: str, **kwargs) -> None:
         bucket, key, version_id = self.parse_path(path)
@@ -711,13 +754,11 @@ def put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs)
 
         if content_type is not None:
             kwargs["ContentType"] = content_type
 
-        with (
-            self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote,
-            open(lpath, "rb") as local,
-        ):
-            while data := local.read(remote.blocksize):
-                remote.write(data)
-                callback.relative_update(len(data))
+        with self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote:
+            with open(lpath, "rb") as local:
+                while data := local.read(remote.blocksize):
+                    remote.write(data)
+                    callback.relative_update(len(data))
 
         self.invalidate_cache(rpath)
@@ -725,18 +766,20 @@ def get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, outfile=N
         if os.path.isdir(lpath):
             return
 
-        with open(lpath, "wb") as local, self.open(rpath, "rb", **kwargs) as remote:
-            callback.set_size(remote.size)
-            while data := remote.read(remote.blocksize):
-                local.write(data)
-                callback.relative_update(len(data))
+        with open(lpath, "wb") as local:
+            with self.open(rpath, "rb", **kwargs) as remote:
+                callback.set_size(remote.size)
+                while data := remote.read(remote.blocksize):
+                    local.write(data)
+                    callback.relative_update(len(data))
 
     def checksum(self, path: str, **kwargs):
         refresh = kwargs.pop("refresh", False)
         info = self.info(path, refresh=refresh)
         if info.get("type") != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
             return int(info.get("etag").strip('"').split("-")[0], 16)
-        return int(tokenize(info), 16)
+        else:
+            return int(tokenize(info), 16)
 
     def sign(self, path: str, expiration: int = 3600, **kwargs):
         bucket, key, version_id = self.parse_path(path)
@@ -933,7 +976,10 @@ def _complete_multipart_upload(
         return S3CompleteMultipartUpload(response)
 
     def _call(self, method: Union[str, Callable[..., Any]], **kwargs) -> Dict[str, Any]:
-        func = getattr(self._client, method) if isinstance(method, str) else method
+        if isinstance(method, str):
+            func = getattr(self._client, method)
+        else:
+            func = method
         response = retry_api_call(
             func, config=self._retry_config, logger=_logger, **kwargs, **self.request_kwargs
         )
@@ -1218,8 +1264,9 @@ def _get_ranges(
             if range_end > end:
                 ranges.append((range_start, end))
                 break
-            ranges.append((range_start, range_end))
-            range_start += worker_block_size
+            else:
+                ranges.append((range_start, range_end))
+                range_start += worker_block_size
         else:
             ranges.append((start, end))
         return ranges
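
Note on the _ls_dirs change above (editorial sketch, not part of the patch): cached entries now record the delimiter they were listed with, and a dircache hit is reused only when every cached S3Object carries the delimiter being requested; otherwise the method falls through to a fresh list_objects_v2 loop. A minimal self-contained illustration of that rule, where the Entry class and cached_listing function are hypothetical stand-ins for S3Object and the check inside _ls_dirs:

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class Entry:
    """Stand-in for S3Object: remembers the delimiter it was listed with."""
    name: str
    delimiter: str


def cached_listing(
    dircache: Dict[str, Union[Entry, List[Entry]]],
    path: str,
    delimiter: str,
    refresh: bool = False,
) -> Optional[List[Entry]]:
    """Reuse a cached listing only when refresh is not forced and every
    cached entry was produced with the same delimiter; returning None
    signals the caller to fall through to a real list_objects_v2 loop."""
    if path in dircache and not refresh:
        cache = dircache[path]
        caches = cache if isinstance(cache, list) else [cache]
        if all(f.delimiter == delimiter for f in caches):
            return caches
    return None


# A shallow, directory-style listing (delimiter="/") cached for a prefix
# must not satisfy a recursive find (delimiter=""), and vice versa.
cache = {"bucket/prefix": [Entry("bucket/prefix/a", "/")]}
assert cached_listing(cache, "bucket/prefix", "/") is not None
assert cached_listing(cache, "bucket/prefix", "") is None

This appears to be why the new _find (which lists with delimiter="") and the directory-style ls path can share dircache without returning each other's results.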