Commit e08792a

Remove open_async usage in put raw data (#2998)
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent: 576fb26 · commit: e08792a

File tree: 4 files changed (+58, -29 lines)


flytekit/core/data_persistence.py

Lines changed: 13 additions & 26 deletions
@@ -423,47 +423,34 @@ async def async_put_raw_data(
         r = await self._put(from_path, to_path, **kwargs)
         return r or to_path
 
+        # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to
+        # support effectively async open(). For now these use-cases below will revert to sync calls.
         # raw bytes
         if isinstance(lpath, bytes):
-            fs = await self.get_async_filesystem_for_path(to_path)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    s.write(lpath)
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    s.write(lpath)
-
+            fs = self.get_filesystem_for_path(to_path)
+            with fs.open(to_path, "wb", **kwargs) as s:
+                s.write(lpath)
             return to_path
 
         # If lpath is a buffered reader of some kind
         if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
             if not lpath.readable():
                 raise FlyteAssertion("Buffered reader must be readable")
-            fs = await self.get_async_filesystem_for_path(to_path)
+            fs = self.get_filesystem_for_path(to_path)
             lpath.seek(0)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    while data := lpath.read(read_chunk_size_bytes):
-                        s.write(data)
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    while data := lpath.read(read_chunk_size_bytes):
-                        s.write(data)
+            with fs.open(to_path, "wb", **kwargs) as s:
+                while data := lpath.read(read_chunk_size_bytes):
+                    s.write(data)
             return to_path
 
         if isinstance(lpath, io.StringIO):
             if not lpath.readable():
                 raise FlyteAssertion("Buffered reader must be readable")
-            fs = await self.get_async_filesystem_for_path(to_path)
+            fs = self.get_filesystem_for_path(to_path)
             lpath.seek(0)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    while data_str := lpath.read(read_chunk_size_bytes):
-                        s.write(data_str.encode(encoding))
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    while data_str := lpath.read(read_chunk_size_bytes):
-                        s.write(data_str.encode(encoding))
+            with fs.open(to_path, "wb", **kwargs) as s:
+                while data_str := lpath.read(read_chunk_size_bytes):
+                    s.write(data_str.encode(encoding))
             return to_path
 
         raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}")

plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -69,9 +69,10 @@ def encode(
         df.to_parquet(output_bytes)
 
         if structured_dataset.uri is not None:
+            output_bytes.seek(0)
             fs = ctx.file_access.get_filesystem_for_path(path=structured_dataset.uri)
             with fs.open(structured_dataset.uri, "wb") as s:
-                s.write(output_bytes)
+                s.write(output_bytes.read())
             output_uri = structured_dataset.uri
         else:
             remote_fn = "00000"  # 00000 is our default unnamed parquet filename

plugins/flytekit-polars/tests/test_polars_plugin_sd.py

Lines changed: 26 additions & 1 deletion
@@ -5,7 +5,7 @@
 import pytest
 from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
 from typing_extensions import Annotated
-from packaging import version
+import numpy as np
 from polars.testing import assert_frame_equal
 
 from flytekit import kwtypes, task, workflow
@@ -134,3 +134,28 @@ def consume_sd_return_sd(sd: StructuredDataset) -> StructuredDataset:
     opened_sd = opened_sd.collect()
 
     assert_frame_equal(opened_sd, polars_df)
+
+
+def test_with_uri():
+    temp_file = tempfile.mktemp()
+
+    @task
+    def random_dataframe(num_rows: int) -> StructuredDataset:
+        feature_1_list = np.random.randint(low=100, high=999, size=(num_rows,))
+        feature_2_list = np.random.normal(loc=0, scale=1, size=(num_rows, ))
+        pl_df = pl.DataFrame({'protein_length': feature_1_list,
+                              'protein_feature': feature_2_list})
+        sd = StructuredDataset(dataframe=pl_df, uri=temp_file)
+        return sd
+
+    @task
+    def consume(df: pd.DataFrame):
+        print(df.head(5))
+        print(df.describe())
+
+    @workflow
+    def my_wf(num_rows: int):
+        pl = random_dataframe(num_rows=num_rows)
+        consume(pl)
+
+    my_wf(num_rows=100)

tests/flytekit/unit/core/test_data_persistence.py

Lines changed: 17 additions & 1 deletion
@@ -1,16 +1,17 @@
 import io
 import os
-import fsspec
 import pathlib
 import random
 import string
 import sys
 import tempfile
 
+import fsspec
 import mock
 import pytest
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
 
+from flytekit.configuration import Config
 from flytekit.core.data_persistence import FileAccessProvider
 from flytekit.core.local_fsspec import FlyteLocalFileSystem
 
@@ -207,3 +208,18 @@ def __init__(self, *args, **kwargs):
 
     fp = FileAccessProvider("/tmp", "s3://my-bucket")
     fp.get_filesystem("testgetfs", test_arg="test_arg")
+
+
+@pytest.mark.sandbox_test
+def test_put_raw_data_bytes():
+    dc = Config.for_sandbox().data_config
+    raw_output = f"s3://my-s3-bucket/"
+    provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc)
+    prefix = provider.get_random_string()
+    provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes")
+    provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io")
+    provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io")
+
+    fs = provider.get_filesystem("s3")
+    listing = fs.ls(f"{raw_output}{prefix}/")
+    assert len(listing) == 3
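The sandbox-marked test above expects a running Flyte sandbox and an s3://my-s3-bucket bucket. A rough local-only variant of the same three calls, assuming the FileAccessProvider arguments shown in the test and a temporary directory as the raw output prefix (a sketch, not part of this commit):

import io
import tempfile

from flytekit.core.data_persistence import FileAccessProvider


def put_raw_data_locally():
    raw_output = tempfile.mkdtemp()
    provider = FileAccessProvider(local_sandbox_dir=tempfile.mkdtemp(), raw_output_prefix=raw_output)
    prefix = provider.get_random_string()
    provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes")
    provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io")
    provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io")

    # List what landed under the prefix; assumes the local fsspec "file" protocol.
    fs = provider.get_filesystem("file")
    assert len(fs.ls(f"{raw_output}/{prefix}/")) == 3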
