Skip to content

Commit 3fc51af

Browse files
Async/data persistence (#2829)
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent f1adbe8 commit 3fc51af

File tree

19 files changed

+266
-98
lines changed

19 files changed

+266
-98
lines changed

flytekit/core/data_persistence.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
1919
"""
2020

21+
import asyncio
2122
import io
2223
import os
2324
import pathlib
@@ -29,6 +30,7 @@
2930

3031
import fsspec
3132
from decorator import decorator
33+
from fsspec.asyn import AsyncFileSystem
3234
from fsspec.utils import get_protocol
3335
from typing_extensions import Unpack
3436

@@ -40,6 +42,7 @@
4042
from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException
4143
from flytekit.interfaces.random import random
4244
from flytekit.loggers import logger
45+
from flytekit.utils.asyn import loop_manager
4346

4447
# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198
4548
# for key and secret
@@ -208,8 +211,17 @@ def get_filesystem(
208211
storage_options = get_fsspec_storage_options(
209212
protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
210213
)
214+
kwargs.update(storage_options)
211215

212-
return fsspec.filesystem(protocol, **storage_options)
216+
return fsspec.filesystem(protocol, **kwargs)
217+
218+
async def get_async_filesystem_for_path(
219+
self, path: str = "", anonymous: bool = False, **kwargs
220+
) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]:
221+
protocol = get_protocol(path)
222+
loop = asyncio.get_running_loop()
223+
224+
return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs)
213225

214226
def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
215227
protocol = get_protocol(path)
@@ -282,8 +294,8 @@ def exists(self, path: str) -> bool:
282294
raise oe
283295

284296
@retry_request
285-
def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
286-
file_system = self.get_filesystem_for_path(from_path)
297+
async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
298+
file_system = await self.get_async_filesystem_for_path(from_path)
287299
if recursive:
288300
from_path, to_path = self.recursive_paths(from_path, to_path)
289301
try:
@@ -294,23 +306,33 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
294306
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
295307
)
296308
logger.info(f"Getting {from_path} to {to_path}")
297-
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
309+
if isinstance(file_system, AsyncFileSystem):
310+
dst = await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
311+
else:
312+
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
298313
if isinstance(dst, (str, pathlib.Path)):
299314
return dst
300315
return to_path
301316
except OSError as oe:
302317
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
303318
if not file_system.exists(from_path):
304319
raise FlyteDataNotFoundException(from_path)
305-
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True)
320+
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
306321
if file_system is not None:
307322
logger.debug(f"Attempting anonymous get with {file_system}")
308-
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
323+
if isinstance(file_system, AsyncFileSystem):
324+
return await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
325+
else:
326+
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
309327
raise oe
310328

311329
@retry_request
312-
def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
313-
file_system = self.get_filesystem_for_path(to_path)
330+
async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
331+
"""
332+
More of an internal function to be called by put_data and put_raw_data
333+
This does not need a separate sync function.
334+
"""
335+
file_system = await self.get_async_filesystem_for_path(to_path)
314336
from_path = self.strip_file_header(from_path)
315337
if recursive:
316338
# Only check this for the local filesystem
@@ -327,13 +349,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
327349
if "metadata" not in kwargs:
328350
kwargs["metadata"] = {}
329351
kwargs["metadata"].update(self._execution_metadata)
330-
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
352+
if isinstance(file_system, AsyncFileSystem):
353+
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
354+
else:
355+
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
331356
if isinstance(dst, (str, pathlib.Path)):
332357
return dst
333358
else:
334359
return to_path
335360

336-
def put_raw_data(
361+
async def async_put_raw_data(
337362
self,
338363
lpath: Uploadable,
339364
upload_prefix: Optional[str] = None,
@@ -364,7 +389,7 @@ def put_raw_data(
364389
:param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it
365390
:param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary.
366391
:param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix
367-
:param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call
392+
:param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call
368393
:return: Returns the final path data was written to.
369394
"""
370395
# First figure out what the destination path should be, then call put.
@@ -388,42 +413,60 @@ def put_raw_data(
388413
raise FlyteAssertion(f"File {from_path} is a symlink, can't upload")
389414
if p.is_dir():
390415
logger.debug(f"Detected directory {from_path}, using recursive put")
391-
r = self.put(from_path, to_path, recursive=True, **kwargs)
416+
r = await self._put(from_path, to_path, recursive=True, **kwargs)
392417
else:
393418
logger.debug(f"Detected file {from_path}, call put non-recursive")
394-
r = self.put(from_path, to_path, **kwargs)
419+
r = await self._put(from_path, to_path, **kwargs)
395420
return r or to_path
396421

397422
# raw bytes
398423
if isinstance(lpath, bytes):
399-
fs = self.get_filesystem_for_path(to_path)
400-
with fs.open(to_path, "wb", **kwargs) as s:
401-
s.write(lpath)
424+
fs = await self.get_async_filesystem_for_path(to_path)
425+
if isinstance(fs, AsyncFileSystem):
426+
async with fs.open_async(to_path, "wb", **kwargs) as s:
427+
s.write(lpath)
428+
else:
429+
with fs.open(to_path, "wb", **kwargs) as s:
430+
s.write(lpath)
431+
402432
return to_path
403433

404434
# If lpath is a buffered reader of some kind
405435
if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
406436
if not lpath.readable():
407437
raise FlyteAssertion("Buffered reader must be readable")
408-
fs = self.get_filesystem_for_path(to_path)
438+
fs = await self.get_async_filesystem_for_path(to_path)
409439
lpath.seek(0)
410-
with fs.open(to_path, "wb", **kwargs) as s:
411-
while data := lpath.read(read_chunk_size_bytes):
412-
s.write(data)
440+
if isinstance(fs, AsyncFileSystem):
441+
async with fs.open_async(to_path, "wb", **kwargs) as s:
442+
while data := lpath.read(read_chunk_size_bytes):
443+
s.write(data)
444+
else:
445+
with fs.open(to_path, "wb", **kwargs) as s:
446+
while data := lpath.read(read_chunk_size_bytes):
447+
s.write(data)
413448
return to_path
414449

415450
if isinstance(lpath, io.StringIO):
416451
if not lpath.readable():
417452
raise FlyteAssertion("Buffered reader must be readable")
418-
fs = self.get_filesystem_for_path(to_path)
453+
fs = await self.get_async_filesystem_for_path(to_path)
419454
lpath.seek(0)
420-
with fs.open(to_path, "wb", **kwargs) as s:
421-
while data_str := lpath.read(read_chunk_size_bytes):
422-
s.write(data_str.encode(encoding))
455+
if isinstance(fs, AsyncFileSystem):
456+
async with fs.open_async(to_path, "wb", **kwargs) as s:
457+
while data_str := lpath.read(read_chunk_size_bytes):
458+
s.write(data_str.encode(encoding))
459+
else:
460+
with fs.open(to_path, "wb", **kwargs) as s:
461+
while data_str := lpath.read(read_chunk_size_bytes):
462+
s.write(data_str.encode(encoding))
423463
return to_path
424464

425465
raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}")
426466

467+
# Public synchronous version
468+
put_raw_data = loop_manager.synced(async_put_raw_data)
469+
427470
@staticmethod
428471
def get_random_string() -> str:
429472
return UUID(int=random.getrandbits(128)).hex
@@ -549,7 +592,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs):
549592
"""
550593
return self.put_data(local_path, remote_path, is_multipart=True, **kwargs)
551594

552-
def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
595+
async def async_get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
553596
"""
554597
:param remote_path:
555598
:param local_path:
@@ -558,7 +601,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
558601
try:
559602
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
560603
with timeit(f"Download data to local from {remote_path}"):
561-
self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
604+
await self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
562605
except FlyteDataNotFoundException:
563606
raise
564607
except Exception as ex:
@@ -567,7 +610,9 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
567610
f"Original exception: {str(ex)}"
568611
)
569612

570-
def put_data(
613+
get_data = loop_manager.synced(async_get_data)
614+
615+
async def async_put_data(
571616
self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs
572617
) -> str:
573618
"""
@@ -581,7 +626,7 @@ def put_data(
581626
try:
582627
local_path = str(local_path)
583628
with timeit(f"Upload data to {remote_path}"):
584-
put_result = self.put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
629+
put_result = await self._put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
585630
# This is an unfortunate workaround to ensure that we return the correct path for the remote location
586631
# Callers of this put_data function in flytekit have been changed to assign the remote path to the
587632
# output
@@ -595,6 +640,9 @@ def put_data(
595640
f"Original exception: {str(ex)}"
596641
) from ex
597642

643+
# Public synchronous version
644+
put_data = loop_manager.synced(async_put_data)
645+
598646

599647
flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
600648
default_local_file_access_provider = FileAccessProvider(

flytekit/core/type_engine.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,7 +1908,9 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
19081908
return None, None
19091909

19101910
@staticmethod
1911-
def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal:
1911+
async def dict_to_binary_literal(
1912+
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
1913+
) -> Literal:
19121914
"""
19131915
Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding.
19141916
Falls back to Pickle if encoding fails and `allow_pickle` is True.
@@ -1922,7 +1924,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict],
19221924
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
19231925
except TypeError as e:
19241926
if allow_pickle:
1925-
remote_path = FlytePickle.to_pickle(ctx, v)
1927+
remote_path = await FlytePickle.to_pickle(ctx, v)
19261928
return Literal(
19271929
scalar=Scalar(
19281930
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
@@ -1980,7 +1982,7 @@ async def async_to_literal(
19801982
allow_pickle, base_type = DictTransformer.is_pickle(python_type)
19811983

19821984
if expected and expected.simple and expected.simple == SimpleType.STRUCT:
1983-
return self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)
1985+
return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)
19841986

19851987
lit_map = {}
19861988
for k, v in python_val.items():
@@ -2036,7 +2038,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
20362038
from flytekit.types.pickle import FlytePickle
20372039

20382040
uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
2039-
return FlytePickle.from_pickle(uri)
2041+
return await FlytePickle.from_pickle(uri)
20402042

20412043
try:
20422044
return json.loads(_json_format.MessageToJson(lv.scalar.generic))

flytekit/extend/backend/base_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ async def _create(
368368
literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
369369
path = ctx.file_access.get_random_local_path()
370370
utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
371-
ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
371+
await ctx.file_access.async_put_data(path, f"{output_prefix}/inputs.pb")
372372
task_template = render_task_template(task_template, output_prefix)
373373
else:
374374
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())

flytekit/extras/tensorflow/model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import tensorflow as tf
55

66
from flytekit.core.context_manager import FlyteContext
7-
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
7+
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
88
from flytekit.models.core import types as _core_types
99
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
1010
from flytekit.models.types import LiteralType
1111

1212

13-
class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
13+
class TensorFlowModelTransformer(AsyncTypeTransformer[tf.keras.Model]):
1414
TENSORFLOW_FORMAT = "TensorFlowModel"
1515

1616
def __init__(self):
@@ -24,7 +24,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
2424
)
2525
)
2626

27-
def to_literal(
27+
async def async_to_literal(
2828
self,
2929
ctx: FlyteContext,
3030
python_val: tf.keras.Model,
@@ -44,10 +44,10 @@ def to_literal(
4444
# save model in SavedModel format
4545
tf.keras.models.save_model(python_val, local_path)
4646

47-
remote_path = ctx.file_access.put_raw_data(local_path)
47+
remote_path = await ctx.file_access.async_put_raw_data(local_path)
4848
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
4949

50-
def to_python_value(
50+
async def async_to_python_value(
5151
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model]
5252
) -> tf.keras.Model:
5353
try:
@@ -56,7 +56,7 @@ def to_python_value(
5656
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
5757

5858
local_path = ctx.file_access.get_random_local_path()
59-
ctx.file_access.get_data(uri, local_path, is_multipart=True)
59+
await ctx.file_access.async_get_data(uri, local_path, is_multipart=True)
6060

6161
# load model
6262
return tf.keras.models.load_model(local_path)

flytekit/types/directory/types.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from flytekit import BlobType
2020
from flytekit.core.constants import MESSAGEPACK
2121
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
22-
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size
22+
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
2323
from flytekit.exceptions.user import FlyteAssertion
2424
from flytekit.models import types as _type_models
2525
from flytekit.models.core import types as _core_types
@@ -407,7 +407,7 @@ def __str__(self):
407407
return str(self.path)
408408

409409

410-
class FlyteDirToMultipartBlobTransformer(TypeTransformer[FlyteDirectory]):
410+
class FlyteDirToMultipartBlobTransformer(AsyncTypeTransformer[FlyteDirectory]):
411411
"""
412412
This transformer handles conversion between the Python native FlyteDirectory class defined above, and the Flyte
413413
IDL literal/type of Multipart Blob. Please see the FlyteDirectory comments for additional information.
@@ -444,7 +444,7 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec
444444
def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType:
445445
return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t)))
446446

447-
def to_literal(
447+
async def async_to_literal(
448448
self,
449449
ctx: FlyteContext,
450450
python_val: FlyteDirectory,
@@ -499,7 +499,9 @@ def to_literal(
499499
remote_directory = ctx.file_access.get_random_remote_directory()
500500
if not pathlib.Path(source_path).is_dir():
501501
raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path))
502-
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size)
502+
await ctx.file_access.async_put_data(
503+
source_path, remote_directory, is_multipart=True, batch_size=batch_size
504+
)
503505
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))
504506

505507
# If not uploading, then we can only take the original source path as the uri.
@@ -535,7 +537,7 @@ def from_binary_idl(
535537
else:
536538
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
537539

538-
def to_python_value(
540+
async def async_to_python_value(
539541
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
540542
) -> FlyteDirectory:
541543
if lv.scalar.binary:

0 commit comments

Comments
 (0)