Commit dd13a29

WIP
laughingman7743 committed Dec 29, 2024
1 parent b58ec77 commit dd13a29
Showing 1 changed file with 139 additions and 92 deletions.
231 changes: 139 additions & 92 deletions pyathena/filesystem/s3.py
@@ -125,11 +125,11 @@ def _get_client_compatible_with_s3fs(self, **kwargs) -> BaseClient:
if anon:
config_kwargs.update({"signature_version": UNSIGNED})
else:
creds = {
"aws_access_key_id": kwargs.pop("key", kwargs.pop("username", None)),
"aws_secret_access_key": kwargs.pop("secret", kwargs.pop("password", None)),
"aws_session_token": kwargs.pop("token", None),
}
creds = dict(
aws_access_key_id=kwargs.pop("key", kwargs.pop("username", None)),
aws_secret_access_key=kwargs.pop("secret", kwargs.pop("password", None)),
aws_session_token=kwargs.pop("token", None),
)
kwargs.update(**creds)
client_kwargs.update(**creds)

@@ -148,7 +148,8 @@ def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]:
match = S3FileSystem.PATTERN_PATH.search(path)
if match:
return match.group("bucket"), match.group("key"), match.group("version_id")
raise ValueError(f"Invalid S3 path format {path}.")
else:
raise ValueError(f"Invalid S3 path format {path}.")

def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
if bucket not in self.dircache or refresh:
@@ -173,6 +174,7 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]:
bucket=bucket,
key=None,
version_id=None,
delimiter=None,
)
self.dircache[bucket] = file
else:
@@ -206,6 +208,7 @@ def _head_object(
bucket=bucket,
key=key,
version_id=version_id,
delimiter=None,
)
self.dircache[path] = file
else:
@@ -230,6 +233,7 @@ def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
bucket=b["Name"],
key=None,
version_id=None,
delimiter=None,
)
for b in response["Buckets"]
]
@@ -250,55 +254,63 @@ def _ls_dirs(
bucket, key, version_id = self.parse_path(path)
if key:
prefix = f"{key}/{prefix if prefix else ''}"
if path not in self.dircache or refresh:
files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
)
for c in response.get("CommonPrefixes", [])

if path in self.dircache and not refresh:
cache = self.dircache[path]
if not isinstance(cache, list):
caches = [cache]
else:
caches = cache
if all([f.delimiter == delimiter for f in caches]):
return caches

files: List[S3Object] = []
while True:
request: Dict[Any, Any] = {
"Bucket": bucket,
"Prefix": prefix,
"Delimiter": delimiter,
}
if next_token:
request.update({"ContinuationToken": next_token})
if max_keys:
request.update({"MaxKeys": max_keys})
response = self._call(
self._client.list_objects_v2,
**request,
)
files.extend(
S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=c["Prefix"][:-1].rstrip("/"),
version_id=version_id,
delimiter=delimiter,
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
)
for c in response.get("Contents", [])
for c in response.get("CommonPrefixes", [])
)
files.extend(
S3Object(
init=c,
type=S3ObjectType.S3_OBJECT_TYPE_FILE,
bucket=bucket,
key=c["Key"],
delimiter=delimiter,
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[path] = files
else:
cache = self.dircache[path]
files = cache if isinstance(cache, list) else [cache]
for c in response.get("Contents", [])
)
next_token = response.get("NextContinuationToken")
if not next_token:
break
if files:
self.dircache[path] = files
return files

def ls(
@@ -313,7 +325,7 @@ def ls(
file = self._head_object(path, refresh=refresh)
if file:
files = [file]
return list(files) if detail else [f.name for f in files]
return [f for f in files] if detail else [f.name for f in files]

def info(self, path: str, **kwargs) -> S3Object:
refresh = kwargs.pop("refresh", False)
@@ -333,6 +345,7 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket=bucket,
key=None,
version_id=None,
delimiter=None,
)
if not refresh:
caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path)
@@ -346,19 +359,21 @@ def info(self, path: str, **kwargs) -> S3Object:

if cache:
return cache
return S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
)
else:
return S3Object(
init={
"ContentLength": 0,
"ContentType": None,
"StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY,
"ETag": None,
"LastModified": None,
},
type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY,
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
delimiter=None,
)
if key:
object_info = self._head_object(path, refresh=refresh, version_id=version_id)
if object_info:
@@ -367,7 +382,8 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket_info = self._head_bucket(path, refresh=refresh)
if bucket_info:
return bucket_info
raise FileNotFoundError(path)
else:
raise FileNotFoundError(path)

response = self._call(
self._client.list_objects_v2,
@@ -393,33 +409,54 @@ def info(self, path: str, **kwargs) -> S3Object:
bucket=bucket,
key=key.rstrip("/") if key else None,
version_id=version_id,
delimiter=None,
)
raise FileNotFoundError(path)
else:
raise FileNotFoundError(path)

def find(
def _find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
# TODO: Support maxdepth and withdirs
) -> List[S3Object]:
path = self._strip_protocol(path)
if path in ["", "/"]:
raise ValueError("Cannot traverse all files in S3.")
bucket, key, _ = self.parse_path(path)
prefix = kwargs.pop("prefix", "")
if maxdepth:
return super().find(
path=path,
maxdepth=maxdepth,
withdirs=withdirs,
detail=True,
**kwargs
).values()

files = self._ls_dirs(path, prefix=prefix, delimiter="")
if not files and key:
try:
files = [self.info(path)]
except FileNotFoundError:
files = []
return files

def find(
self,
path: str,
maxdepth: Optional[int] = None,
withdirs: Optional[bool] = None,
detail: bool = False,
**kwargs,
) -> Union[Dict[str, S3Object], List[str]]:
# TODO: Support withdirs
files = self._find(path=path, maxdepth=maxdepth, withdirs=withdirs, **kwargs)
if detail:
return {f.name: f for f in files}
return [f.name for f in files]
else:
return [f.name for f in files]

def exists(self, path: str, **kwargs) -> bool:
path = self._strip_protocol(path)
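
As a rough illustration (not part of the file above), the reworked find() would be exercised like the following sketch; the bucket and prefix are hypothetical, and constructor arguments depend on your AWS setup:

    from pyathena.filesystem.s3 import S3FileSystem

    fs = S3FileSystem()  # credentials/region resolved from the environment in this sketch

    # Without maxdepth, _find() performs a flat listing (delimiter="") under the prefix,
    # and find() returns the object names as a list of strings.
    names = fs.find("s3://example-bucket/data/")

    # detail=True returns a dict keyed by object name instead of a plain list.
    objects = fs.find("s3://example-bucket/data/", detail=True)

    # With maxdepth set, _find() delegates to the generic fsspec AbstractFileSystem.find().
    shallow = fs.find("s3://example-bucket/data/", maxdepth=1)
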
@@ -432,7 +469,10 @@ def exists(self, path: str, **kwargs) -> bool:
if self._ls_from_cache(path):
return True
info = self.info(path)
return bool(info)
if info:
return True
else:
return False
except FileNotFoundError:
return False
elif self.dircache.get(bucket, False):
@@ -444,7 +484,10 @@ def exists(self, path: str, **kwargs) -> bool:
except FileNotFoundError:
pass
file = self._head_bucket(bucket)
return bool(file)
if file:
return True
else:
return False

def rm_file(self, path: str, **kwargs) -> None:
bucket, key, version_id = self.parse_path(path)
@@ -711,32 +754,32 @@ def put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs)
if content_type is not None:
kwargs["ContentType"] = content_type

with (
self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote,
open(lpath, "rb") as local,
):
while data := local.read(remote.blocksize):
remote.write(data)
callback.relative_update(len(data))
with self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote:
with open(lpath, "rb") as local:
while data := local.read(remote.blocksize):
remote.write(data)
callback.relative_update(len(data))

self.invalidate_cache(rpath)

def get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs):
if os.path.isdir(lpath):
return

with open(lpath, "wb") as local, self.open(rpath, "rb", **kwargs) as remote:
callback.set_size(remote.size)
while data := remote.read(remote.blocksize):
local.write(data)
callback.relative_update(len(data))
with open(lpath, "wb") as local:
with self.open(rpath, "rb", **kwargs) as remote:
callback.set_size(remote.size)
while data := remote.read(remote.blocksize):
local.write(data)
callback.relative_update(len(data))

def checksum(self, path: str, **kwargs):
refresh = kwargs.pop("refresh", False)
info = self.info(path, refresh=refresh)
if info.get("type") != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY:
return int(info.get("etag").strip('"').split("-")[0], 16)
return int(tokenize(info), 16)
else:
return int(tokenize(info), 16)

def sign(self, path: str, expiration: int = 3600, **kwargs):
bucket, key, version_id = self.parse_path(path)
@@ -933,7 +976,10 @@ def _complete_multipart_upload(
return S3CompleteMultipartUpload(response)

def _call(self, method: Union[str, Callable[..., Any]], **kwargs) -> Dict[str, Any]:
func = getattr(self._client, method) if isinstance(method, str) else method
if isinstance(method, str):
func = getattr(self._client, method)
else:
func = method
response = retry_api_call(
func, config=self._retry_config, logger=_logger, **kwargs, **self.request_kwargs
)
@@ -1218,8 +1264,9 @@ def _get_ranges(
if range_end > end:
ranges.append((range_start, end))
break
ranges.append((range_start, range_end))
range_start += worker_block_size
else:
ranges.append((range_start, range_end))
range_start += worker_block_size
else:
ranges.append((start, end))
return ranges

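For context, a minimal sketch of how the delimiter-aware dircache entries introduced in _ls_dirs() interact with the other changed methods; the bucket, prefixes, and local paths are hypothetical, and ls() is assumed to keep listing with a "/" delimiter:

    from pyathena.filesystem.s3 import S3FileSystem

    fs = S3FileSystem()  # construction arguments omitted; depends on your AWS setup

    # ls() lists one level deep; the cached S3Object entries now record the
    # delimiter ("/") they were listed with.
    fs.ls("s3://example-bucket/logs/")

    # find() lists recursively with an empty delimiter; because the cached entries
    # were built with a different delimiter, _ls_dirs() re-issues list_objects_v2
    # instead of reusing the shallow listing.
    fs.find("s3://example-bucket/logs/")

    # exists() and checksum() build on info(), which consults the same cache.
    if fs.exists("s3://example-bucket/logs/2024/12/29.log"):
        print(fs.checksum("s3://example-bucket/logs/2024/12/29.log"))

    # put_file()/get_file() stream in blocksize-sized chunks inside nested context managers.
    fs.put_file("/tmp/report.csv", "s3://example-bucket/reports/report.csv")
    fs.get_file("s3://example-bucket/reports/report.csv", "/tmp/report_copy.csv")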