From cdb56c0d6a3b9cae1ea7cb8e251554735d38ac24 Mon Sep 17 00:00:00 2001 From: JiaWei Jiang Date: Wed, 1 Jan 2025 23:50:10 +0800 Subject: [PATCH 1/7] fix: Retain user-specified file format info Signed-off-by: JiaWei Jiang --- flytekit/types/structured/structured_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index da9cc79753..7dc8de4695 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -275,6 +275,7 @@ def extract_cols_and_format( optional str for the format, optional pyarrow Schema """ + # breakpoint() fmt = "" ordered_dict_cols = None pa_schema = None @@ -691,7 +692,7 @@ async def async_to_literal( # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema python_type, *attrs = extract_cols_and_format(python_type) # In case it's a FlyteSchema - sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) + sdt = StructuredDatasetType(format="parquet") # self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset): # Catch a common mistake @@ -1093,6 +1094,11 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: + # breakpoint() + # if get_origin(t) is Annotated: + # original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore + # else: + # column_map, storage_format, pa_schema = None, t.file_format, None original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information From 36ab6e93ce1515dac35b56409402117f373f329a Mon Sep 17 00:00:00 2001 From: JiaWei Jiang Date: Fri, 3 
Jan 2025 22:32:14 +0800 Subject: [PATCH 2/7] fix: Set sdt format based on user-specified file_format Signed-off-by: JiaWei Jiang --- flytekit/types/structured/structured_dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 7dc8de4695..5cac2d8d91 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -692,7 +692,7 @@ async def async_to_literal( # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema python_type, *attrs = extract_cols_and_format(python_type) # In case it's a FlyteSchema - sdt = StructuredDatasetType(format="parquet") # self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) + sdt = StructuredDatasetType(self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset): # Catch a common mistake @@ -738,10 +738,20 @@ async def async_to_literal( # return StructuredDataset(uri=uri) if python_val.dataframe is None: uri = python_val.uri + file_format = python_val.file_format + + # Check the user-specified uri if not uri: raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}") if not ctx.file_access.is_remote(uri): uri = await ctx.file_access.async_put_raw_data(uri) + + # Check the user-specified file_format + # When users specify file_format for a StructuredDataset, the file_format information must be retained. + # For details, please refer to https://github.com/flyteorg/flyte/issues/6096. 
+ if file_format != GENERIC_FORMAT: + sdt.format = file_format + sd_model = literals.StructuredDataset( uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type=sdt), From 4af1f0046b6e0ac0ce1522ea8ff6d033dfc7c6cb Mon Sep 17 00:00:00 2001 From: JiaWei Jiang Date: Sat, 4 Jan 2025 13:04:54 +0800 Subject: [PATCH 3/7] Remove redundant modification Signed-off-by: JiaWei Jiang --- flytekit/types/structured/structured_dataset.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 5cac2d8d91..2ca4d758bc 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -275,7 +275,6 @@ def extract_cols_and_format( optional str for the format, optional pyarrow Schema """ - # breakpoint() fmt = "" ordered_dict_cols = None pa_schema = None @@ -692,7 +691,7 @@ async def async_to_literal( # Check first to see if it's even an SD type. 
For backwards compatibility, we may be getting a FlyteSchema python_type, *attrs = extract_cols_and_format(python_type) # In case it's a FlyteSchema - sdt = StructuredDatasetType(self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) + sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset): # Catch a common mistake @@ -1104,11 +1103,6 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: - # breakpoint() - # if get_origin(t) is Annotated: - # original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore - # else: - # column_map, storage_format, pa_schema = None, t.file_format, None original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information From 46b7b6280ffd59a19808238cc282dcb599fad9a0 Mon Sep 17 00:00:00 2001 From: JiaWei Jiang Date: Sat, 4 Jan 2025 14:35:20 +0800 Subject: [PATCH 4/7] test: Test file_format attribute alignment in dc.sd Signed-off-by: JiaWei Jiang --- .../integration/remote/test_remote.py | 47 ++++++++++++- .../remote/workflows/basic/sd_attr.py | 68 +++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 tests/flytekit/integration/remote/workflows/basic/sd_attr.py diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f20a33aea8..491a7b33bb 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -14,7 +14,8 @@ from urllib.parse import urlparse import uuid import pytest -import mock +from unittest import mock +from dataclasses import dataclass from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase from 
flytekit.configuration import Config, ImageConfig, SerializationSettings @@ -26,6 +27,7 @@ from flytekit.remote.remote import FlyteRemote from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 from flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient from flytekit.configuration import PlatformConfig @@ -833,3 +835,46 @@ def test_open_ff(): url = urlparse(remote_file_path) bucket, key = url.netloc, url.path.lstrip("/") file_transfer.delete_file(bucket=bucket, key=key) + + +def test_sd_attr(): + """Test correctness of StructuredDataset attributes. + + This test considers only the following condition: + 1. Check StructuredDataset (wrapped in a dataclass) file_format attribute + + We'll make sure uri aligns with the user-specified one in the future. + """ + from workflows.basic.sd_attr import wf + + @dataclass + class DC: + sd: StructuredDataset + + FILE_FORMAT = "parquet" + + # Upload a file to minio s3 bucket + file_transfer = SimpleFileTransfer() + remote_file_path = file_transfer.upload_file(file_type=FILE_FORMAT) + + # Create a dataclass as the workflow input because `pyflyte run` + # can't properly handle input arg `dc` as a json str so far + dc = DC(sd=StructuredDataset(uri=remote_file_path, file_format=FILE_FORMAT)) + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + wf_exec = remote.execute( + wf, + inputs={"dc": dc, "file_format": FILE_FORMAT}, + wait=True, + version=VERSION, + image_config=ImageConfig.from_images(IMAGE), + ) + assert wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {wf_exec.closure.phase}" + assert wf_exec.outputs["o0"].file_format == FILE_FORMAT, ( + f"Workflow output StructuredDataset file_format should align with the user-specified file_format: {FILE_FORMAT}." 
+    ) + + # Delete the remote file to free the space + url = urlparse(remote_file_path) + bucket, key = url.netloc, url.path.lstrip("/") + file_transfer.delete_file(bucket=bucket, key=key) diff --git a/tests/flytekit/integration/remote/workflows/basic/sd_attr.py b/tests/flytekit/integration/remote/workflows/basic/sd_attr.py new file mode 100644 index 0000000000..b357d52fa4 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/sd_attr.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass + +import pandas as pd +from flytekit import task, workflow +from flytekit.types.structured import StructuredDataset + + +@dataclass +class DC: +    sd: StructuredDataset + + +@task +def create_dc(uri: str, file_format: str) -> DC: +    """Create a dataclass with a StructuredDataset attribute. + +    Args: +        uri: File URI. +        file_format: File format, e.g., parquet, csv. + +    Returns: +        dc: A dataclass with a StructuredDataset attribute. +    """ +    dc = DC(sd=StructuredDataset(uri=uri, file_format=file_format)) + +    return dc + + +@task +def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset: +    """Check StructuredDataset file_format attribute. + +    StructuredDataset file_format should align with what users specify. + +    Args: +        sd: Python native StructuredDataset. +        true_file_format: User-specified file_format. +    """ +    assert sd.file_format == true_file_format, ( +        f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}." +    ) +    assert sd._literal_sd.metadata.structured_dataset_type.format == true_file_format, ( +        f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}." 
+ ) + print(f">>> SD <<<\n{sd}") + print(f">>> Literal SD <<<\n{sd._literal_sd}") + print(f">>> SDT <<<\n{sd._literal_sd.metadata.structured_dataset_type}") + print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}") + + return sd + + +@workflow +def wf(dc: DC, file_format: str) -> StructuredDataset: + # Fail to use dc.sd.file_format as the input + sd = check_file_format(sd=dc.sd, true_file_format=file_format) + + return sd + + +if __name__ == "__main__": + # Define inputs + uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet" + file_format = "parquet" + + dc = create_dc(uri=uri, file_format=file_format) + sd = wf(dc=dc, file_format=file_format) + print(sd.file_format) From af1ee980c3e2764b23bdc9fd79159f2a69ef174f Mon Sep 17 00:00:00 2001 From: JiaWei Jiang Date: Sat, 11 Jan 2025 20:26:34 +0800 Subject: [PATCH 5/7] Merge master and support pqt file upload Signed-off-by: JiaWei Jiang --- .../integration/remote/test_remote.py | 33 +++++++++++++++++++ tests/flytekit/integration/remote/utils.py | 3 ++ 2 files changed, 36 insertions(+) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 491a7b33bb..ee179825b8 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -116,6 +116,7 @@ def test_remote_eager_run(): # child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. 
run("eager_example.py", "simple_eager_workflow", "--x", "3") + def test_pydantic_default_input_with_map_task(): execution_id = run("pydantic_wf.py", "wf") remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) @@ -787,6 +788,20 @@ def test_execute_workflow_remote_fn_with_maptask(): ) assert out.outputs["o0"] == [4, 5, 6] + +def test_launch_plans_registrable(): + """Test that a dynamically created launch plan can be registered with the remote.""" + from workflows.basic.array_map import workflow_with_maptask + + from random import choice + from string import ascii_letters + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + version = "".join(choice(ascii_letters) for _ in range(20)) + new_lp = LaunchPlan.create(name="dynamically_created_lp", workflow=workflow_with_maptask) + remote.register_launch_plan(new_lp, version=version) + + def test_register_wf_fast(register): from workflows.basic.subworkflows import parent_wf @@ -837,6 +852,24 @@ def test_open_ff(): file_transfer.delete_file(bucket=bucket, key=key) + +def test_attr_access_sd(): + """Test accessing StructuredDataset attribute from a dataclass.""" + # Upload a file to minio s3 bucket + file_transfer = SimpleFileTransfer() + remote_file_path = file_transfer.upload_file(file_type="parquet") + + execution_id = run("attr_access_sd.py", "wf", "--uri", remote_file_path) + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" + + # Delete the remote file to free the space + url = urlparse(remote_file_path) + bucket, key = url.netloc, url.path.lstrip("/") + file_transfer.delete_file(bucket=bucket, key=key) + + def test_sd_attr(): """Test correctness of 
StructuredDataset attributes. diff --git a/tests/flytekit/integration/remote/utils.py b/tests/flytekit/integration/remote/utils.py index dadc8c6530..c16a0d0f4d 100644 --- a/tests/flytekit/integration/remote/utils.py +++ b/tests/flytekit/integration/remote/utils.py @@ -84,6 +84,9 @@ def _dump_tmp_file(self, file_type: str, tmp_dir: str) -> str: tmp_file_path = pathlib.Path(tmp_dir) / "test.json" with open(tmp_file_path, "w") as f: json.dump(d, f) + elif file_type == "parquet": + # Because `upload_file` accepts a single file only, we specify 00000 to make it a single file + tmp_file_path = pathlib.Path(__file__).parent / "workflows/basic/data/df.parquet/00000" return tmp_file_path From 063801cbe14e30f04bbe934cbc171b1d78dd3ecf Mon Sep 17 00:00:00 2001 From: JiangJiaWei1103 Date: Sun, 19 Jan 2025 17:42:39 +0800 Subject: [PATCH 6/7] Remove redundant condition to always copy file_format over Signed-off-by: JiangJiaWei1103 --- flytekit/types/structured/structured_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 2ca4d758bc..e829e8db96 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -748,8 +748,7 @@ async def async_to_literal( # Check the user-specified file_format # When users specify file_format for a StructuredDataset, the file_format information must be retained. # For details, please refer to https://github.com/flyteorg/flyte/issues/6096. 
- if file_format != GENERIC_FORMAT: - sdt.format = file_format + sdt.format = file_format sd_model = literals.StructuredDataset( uri=uri, From d774df1fed0a162a8bb6c3291633ade991f143e2 Mon Sep 17 00:00:00 2001 From: JiangJiaWei1103 Date: Sat, 25 Jan 2025 11:44:05 +0800 Subject: [PATCH 7/7] Prioritize file_format in type hint over the user-specified one Signed-off-by: JiangJiaWei1103 --- flytekit/types/structured/structured_dataset.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index e829e8db96..9ae143f1de 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -746,9 +746,21 @@ async def async_to_literal( uri = await ctx.file_access.async_put_raw_data(uri) # Check the user-specified file_format - # When users specify file_format for a StructuredDataset, the file_format information must be retained. + # When users specify file_format for a StructuredDataset, the file_format should be retained conditionally. # For details, please refer to https://github.com/flyteorg/flyte/issues/6096. - sdt.format = file_format + # Following illustrates why we can't always copy the user-specified file_format over: + # + # @task + # def modify_format(sd: Annotated[StructuredDataset, {}, "task-format"]) -> StructuredDataset: + # return sd + # + # sd = StructuredDataset(uri="s3://my-s3-bucket/df.parquet", file_format="user-format") + # sd2 = modify_format(sd=sd) + # + # In this case, we expect sd2.file_format to be task-format (as shown in Annotated), not user-format. + # If we directly copy the user-specified file_format over, the type hint information will be missing. + if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT: + sdt.format = file_format sd_model = literals.StructuredDataset( uri=uri,