18
18
19
19
"""
20
20
21
+ import asyncio
21
22
import io
22
23
import os
23
24
import pathlib
29
30
30
31
import fsspec
31
32
from decorator import decorator
33
+ from fsspec .asyn import AsyncFileSystem
32
34
from fsspec .utils import get_protocol
33
35
from typing_extensions import Unpack
34
36
40
42
from flytekit .exceptions .user import FlyteAssertion , FlyteDataNotFoundException
41
43
from flytekit .interfaces .random import random
42
44
from flytekit .loggers import logger
45
+ from flytekit .utils .asyn import loop_manager
43
46
44
47
# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198
45
48
# for key and secret
@@ -208,8 +211,17 @@ def get_filesystem(
208
211
storage_options = get_fsspec_storage_options (
209
212
protocol = protocol , anonymous = anonymous , data_config = self ._data_config , ** kwargs
210
213
)
214
+ kwargs .update (storage_options )
211
215
212
- return fsspec .filesystem (protocol , ** storage_options )
216
+ return fsspec .filesystem (protocol , ** kwargs )
217
+
218
async def get_async_filesystem_for_path(
    self, path: str = "", anonymous: bool = False, **kwargs
) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]:
    """
    Return a filesystem for ``path`` that is tied to the currently running event loop.

    The protocol is derived from ``path`` via ``get_protocol`` and the actual
    construction is delegated to ``self.get_filesystem``, which is asked for an
    instance with ``asynchronous=True`` and ``loop`` set to the running loop.

    :param path: Path or URI whose protocol selects the filesystem; also forwarded as ``path``.
    :param anonymous: Forwarded unchanged to ``get_filesystem``.
    :param kwargs: Any additional keyword arguments are forwarded to ``get_filesystem``.
    :return: Whatever ``get_filesystem`` returns for the resolved protocol.
    """
    scheme = get_protocol(path)
    # This is looked up inside the coroutine on purpose: get_running_loop() only
    # succeeds while an event loop is actually running in the current thread.
    running_loop = asyncio.get_running_loop()

    return self.get_filesystem(
        scheme,
        anonymous=anonymous,
        path=path,
        asynchronous=True,
        loop=running_loop,
        **kwargs,
    )
213
225
214
226
def get_filesystem_for_path (self , path : str = "" , anonymous : bool = False , ** kwargs ) -> fsspec .AbstractFileSystem :
215
227
protocol = get_protocol (path )
@@ -282,8 +294,8 @@ def exists(self, path: str) -> bool:
282
294
raise oe
283
295
284
296
@retry_request
285
- def get (self , from_path : str , to_path : str , recursive : bool = False , ** kwargs ):
286
- file_system = self .get_filesystem_for_path (from_path )
297
+ async def get (self , from_path : str , to_path : str , recursive : bool = False , ** kwargs ):
298
+ file_system = await self .get_async_filesystem_for_path (from_path )
287
299
if recursive :
288
300
from_path , to_path = self .recursive_paths (from_path , to_path )
289
301
try :
@@ -294,23 +306,33 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
294
306
self .strip_file_header (from_path ), self .strip_file_header (to_path ), dirs_exist_ok = True
295
307
)
296
308
logger .info (f"Getting { from_path } to { to_path } " )
297
- dst = file_system .get (from_path , to_path , recursive = recursive , ** kwargs )
309
+ if isinstance (file_system , AsyncFileSystem ):
310
+ dst = await file_system ._get (from_path , to_path , recursive = recursive , ** kwargs ) # pylint: disable=W0212
311
+ else :
312
+ dst = file_system .get (from_path , to_path , recursive = recursive , ** kwargs )
298
313
if isinstance (dst , (str , pathlib .Path )):
299
314
return dst
300
315
return to_path
301
316
except OSError as oe :
302
317
logger .debug (f"Error in getting { from_path } to { to_path } rec { recursive } { oe } " )
303
318
if not file_system .exists (from_path ):
304
319
raise FlyteDataNotFoundException (from_path )
305
- file_system = self .get_filesystem (get_protocol (from_path ), anonymous = True )
320
+ file_system = self .get_filesystem (get_protocol (from_path ), anonymous = True , asynchronous = True )
306
321
if file_system is not None :
307
322
logger .debug (f"Attempting anonymous get with { file_system } " )
308
- return file_system .get (from_path , to_path , recursive = recursive , ** kwargs )
323
+ if isinstance (file_system , AsyncFileSystem ):
324
+ return await file_system ._get (from_path , to_path , recursive = recursive , ** kwargs ) # pylint: disable=W0212
325
+ else :
326
+ return file_system .get (from_path , to_path , recursive = recursive , ** kwargs )
309
327
raise oe
310
328
311
329
@retry_request
312
- def put (self , from_path : str , to_path : str , recursive : bool = False , ** kwargs ):
313
- file_system = self .get_filesystem_for_path (to_path )
330
+ async def _put (self , from_path : str , to_path : str , recursive : bool = False , ** kwargs ):
331
+ """
332
+ More of an internal function to be called by put_data and put_raw_data
333
+ This does not need a separate sync function.
334
+ """
335
+ file_system = await self .get_async_filesystem_for_path (to_path )
314
336
from_path = self .strip_file_header (from_path )
315
337
if recursive :
316
338
# Only check this for the local filesystem
@@ -327,13 +349,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
327
349
if "metadata" not in kwargs :
328
350
kwargs ["metadata" ] = {}
329
351
kwargs ["metadata" ].update (self ._execution_metadata )
330
- dst = file_system .put (from_path , to_path , recursive = recursive , ** kwargs )
352
+ if isinstance (file_system , AsyncFileSystem ):
353
+ dst = await file_system ._put (from_path , to_path , recursive = recursive , ** kwargs ) # pylint: disable=W0212
354
+ else :
355
+ dst = file_system .put (from_path , to_path , recursive = recursive , ** kwargs )
331
356
if isinstance (dst , (str , pathlib .Path )):
332
357
return dst
333
358
else :
334
359
return to_path
335
360
336
- def put_raw_data (
361
+ async def async_put_raw_data (
337
362
self ,
338
363
lpath : Uploadable ,
339
364
upload_prefix : Optional [str ] = None ,
@@ -364,7 +389,7 @@ def put_raw_data(
364
389
:param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it
365
390
:param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary.
366
391
:param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix
367
- :param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call
392
+ :param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call
368
393
:return: Returns the final path data was written to.
369
394
"""
370
395
# First figure out what the destination path should be, then call put.
@@ -388,42 +413,60 @@ def put_raw_data(
388
413
raise FlyteAssertion (f"File { from_path } is a symlink, can't upload" )
389
414
if p .is_dir ():
390
415
logger .debug (f"Detected directory { from_path } , using recursive put" )
391
- r = self .put (from_path , to_path , recursive = True , ** kwargs )
416
+ r = await self ._put (from_path , to_path , recursive = True , ** kwargs )
392
417
else :
393
418
logger .debug (f"Detected file { from_path } , call put non-recursive" )
394
- r = self .put (from_path , to_path , ** kwargs )
419
+ r = await self ._put (from_path , to_path , ** kwargs )
395
420
return r or to_path
396
421
397
422
# raw bytes
398
423
if isinstance (lpath , bytes ):
399
- fs = self .get_filesystem_for_path (to_path )
400
- with fs .open (to_path , "wb" , ** kwargs ) as s :
401
- s .write (lpath )
424
+ fs = await self .get_async_filesystem_for_path (to_path )
425
+ if isinstance (fs , AsyncFileSystem ):
426
+ async with fs .open_async (to_path , "wb" , ** kwargs ) as s :
427
+ s .write (lpath )
428
+ else :
429
+ with fs .open (to_path , "wb" , ** kwargs ) as s :
430
+ s .write (lpath )
431
+
402
432
return to_path
403
433
404
434
# If lpath is a buffered reader of some kind
405
435
if isinstance (lpath , io .BufferedReader ) or isinstance (lpath , io .BytesIO ):
406
436
if not lpath .readable ():
407
437
raise FlyteAssertion ("Buffered reader must be readable" )
408
- fs = self .get_filesystem_for_path (to_path )
438
+ fs = await self .get_async_filesystem_for_path (to_path )
409
439
lpath .seek (0 )
410
- with fs .open (to_path , "wb" , ** kwargs ) as s :
411
- while data := lpath .read (read_chunk_size_bytes ):
412
- s .write (data )
440
+ if isinstance (fs , AsyncFileSystem ):
441
+ async with fs .open_async (to_path , "wb" , ** kwargs ) as s :
442
+ while data := lpath .read (read_chunk_size_bytes ):
443
+ s .write (data )
444
+ else :
445
+ with fs .open (to_path , "wb" , ** kwargs ) as s :
446
+ while data := lpath .read (read_chunk_size_bytes ):
447
+ s .write (data )
413
448
return to_path
414
449
415
450
if isinstance (lpath , io .StringIO ):
416
451
if not lpath .readable ():
417
452
raise FlyteAssertion ("Buffered reader must be readable" )
418
- fs = self .get_filesystem_for_path (to_path )
453
+ fs = await self .get_async_filesystem_for_path (to_path )
419
454
lpath .seek (0 )
420
- with fs .open (to_path , "wb" , ** kwargs ) as s :
421
- while data_str := lpath .read (read_chunk_size_bytes ):
422
- s .write (data_str .encode (encoding ))
455
+ if isinstance (fs , AsyncFileSystem ):
456
+ async with fs .open_async (to_path , "wb" , ** kwargs ) as s :
457
+ while data_str := lpath .read (read_chunk_size_bytes ):
458
+ s .write (data_str .encode (encoding ))
459
+ else :
460
+ with fs .open (to_path , "wb" , ** kwargs ) as s :
461
+ while data_str := lpath .read (read_chunk_size_bytes ):
462
+ s .write (data_str .encode (encoding ))
423
463
return to_path
424
464
425
465
raise FlyteAssertion (f"Unsupported lpath type { type (lpath )} " )
426
466
467
+ # Public synchronous version
468
+ put_raw_data = loop_manager .synced (async_put_raw_data )
469
+
427
470
@staticmethod
428
471
def get_random_string () -> str :
429
472
return UUID (int = random .getrandbits (128 )).hex
@@ -549,7 +592,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs):
549
592
"""
550
593
return self .put_data (local_path , remote_path , is_multipart = True , ** kwargs )
551
594
552
- def get_data (self , remote_path : str , local_path : str , is_multipart : bool = False , ** kwargs ):
595
+ async def async_get_data (self , remote_path : str , local_path : str , is_multipart : bool = False , ** kwargs ):
553
596
"""
554
597
:param remote_path:
555
598
:param local_path:
@@ -558,7 +601,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
558
601
try :
559
602
pathlib .Path (local_path ).parent .mkdir (parents = True , exist_ok = True )
560
603
with timeit (f"Download data to local from { remote_path } " ):
561
- self .get (remote_path , to_path = local_path , recursive = is_multipart , ** kwargs )
604
+ await self .get (remote_path , to_path = local_path , recursive = is_multipart , ** kwargs )
562
605
except FlyteDataNotFoundException :
563
606
raise
564
607
except Exception as ex :
@@ -567,7 +610,9 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
567
610
f"Original exception: { str (ex )} "
568
611
)
569
612
570
- def put_data (
613
+ get_data = loop_manager .synced (async_get_data )
614
+
615
+ async def async_put_data (
571
616
self , local_path : Union [str , os .PathLike ], remote_path : str , is_multipart : bool = False , ** kwargs
572
617
) -> str :
573
618
"""
@@ -581,7 +626,7 @@ def put_data(
581
626
try :
582
627
local_path = str (local_path )
583
628
with timeit (f"Upload data to { remote_path } " ):
584
- put_result = self .put (cast (str , local_path ), remote_path , recursive = is_multipart , ** kwargs )
629
+ put_result = await self ._put (cast (str , local_path ), remote_path , recursive = is_multipart , ** kwargs )
585
630
# This is an unfortunate workaround to ensure that we return the correct path for the remote location
586
631
# Callers of this put_data function in flytekit have been changed to assign the remote path to the
587
632
# output
@@ -595,6 +640,9 @@ def put_data(
595
640
f"Original exception: { str (ex )} "
596
641
) from ex
597
642
643
+ # Public synchronous version
644
+ put_data = loop_manager .synced (async_put_data )
645
+
598
646
599
647
flyte_tmp_dir = tempfile .mkdtemp (prefix = "flyte-" )
600
648
default_local_file_access_provider = FileAccessProvider (
0 commit comments