Skip to content

[BUG] Fix StructuredDataset empty-str file_format in dc attr access #3027

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,31 @@ async def async_to_literal(
# return StructuredDataset(uri=uri)
if python_val.dataframe is None:
uri = python_val.uri
file_format = python_val.file_format

# Check the user-specified uri
if not uri:
raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
if not ctx.file_access.is_remote(uri):
uri = await ctx.file_access.async_put_raw_data(uri)

# Check the user-specified file_format
# When users specify file_format for a StructuredDataset, the file_format should be retained conditionally.
# For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
# The following example illustrates why we can't always copy the user-specified file_format over:
#
# @task
# def modify_format(sd: Annotated[StructuredDataset, {}, "task-format"]) -> StructuredDataset:
# return sd
#
# sd = StructuredDataset(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
# sd2 = modify_format(sd=sd)
#
# In this case, we expect sd2.file_format to be task-format (as shown in Annotated), not user-format.
# If we directly copy the user-specified file_format over, the type hint information will be missing.
if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
sdt.format = file_format

sd_model = literals.StructuredDataset(
uri=uri,
metadata=StructuredDatasetMetadata(structured_dataset_type=sdt),
Expand Down
46 changes: 46 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import uuid
import pytest
from unittest import mock
from dataclasses import dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
Expand All @@ -27,6 +28,7 @@
from flytekit.remote.remote import FlyteRemote
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

Expand Down Expand Up @@ -877,6 +879,50 @@ def test_attr_access_sd():
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)


def test_sd_attr():
    """Test correctness of StructuredDataset attributes.

    This test considers only the following condition:
    1. Check StructuredDataset (wrapped in a dataclass) file_format attribute

    We'll make sure uri aligns with the user-specified one in the future.

    The uploaded fixture file is removed in a ``finally`` clause so that a
    failed execution or a failed assertion does not leave it behind in the
    bucket.
    """
    from workflows.basic.sd_attr import wf

    @dataclass
    class DC:
        sd: StructuredDataset

    FILE_FORMAT = "parquet"

    # Upload a file to minio s3 bucket
    file_transfer = SimpleFileTransfer()
    remote_file_path = file_transfer.upload_file(file_type=FILE_FORMAT)

    try:
        # Create a dataclass as the workflow input because `pyflyte run`
        # can't properly handle input arg `dc` as a json str so far
        dc = DC(sd=StructuredDataset(uri=remote_file_path, file_format=FILE_FORMAT))

        remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True)
        wf_exec = remote.execute(
            wf,
            inputs={"dc": dc, "file_format": FILE_FORMAT},
            wait=True,
            version=VERSION,
            image_config=ImageConfig.from_images(IMAGE),
        )
        assert wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {wf_exec.closure.phase}"
        assert wf_exec.outputs["o0"].file_format == FILE_FORMAT, (
            f"Workflow output StructuredDataset file_format should align with the user-specified file_format: {FILE_FORMAT}."
        )
    finally:
        # Delete the remote file to free the space, even when the checks above fail
        url = urlparse(remote_file_path)
        bucket, key = url.netloc, url.path.lstrip("/")
        file_transfer.delete_file(bucket=bucket, key=key)


def test_signal_approve_reject(register):
from flytekit.models.types import LiteralType, SimpleType
from time import sleep
Expand Down
68 changes: 68 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/sd_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass

import pandas as pd
from flytekit import task, workflow
from flytekit.types.structured import StructuredDataset


# Dataclass wrapper around a single StructuredDataset; it is the return type
# of `create_dc` and the input type of `wf` in this module.
@dataclass
class DC:
    # The wrapped StructuredDataset whose file_format is checked by the workflow.
    sd: StructuredDataset


@task
def create_dc(uri: str, file_format: str) -> DC:
    """Wrap a StructuredDataset built from ``uri`` and ``file_format`` in a ``DC``.

    Args:
        uri: File URI.
        file_format: File format, e.g., parquet, csv.

    Returns:
        A ``DC`` instance whose ``sd`` attribute is the new StructuredDataset.
    """
    wrapped_sd = StructuredDataset(uri=uri, file_format=file_format)
    return DC(sd=wrapped_sd)


@task
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset:
    """Verify that a StructuredDataset carries the expected file_format.

    Both the Python-level ``file_format`` attribute and the ``format`` stored
    in the literal's StructuredDatasetType should align with what users
    specify. After the checks, the dataset, its literal, its type and its
    dataframe content are printed for debugging.

    Args:
        sd: Python native StructuredDataset.
        true_file_format: User-specified file_format.

    Returns:
        The same ``sd`` that was passed in.
    """
    # Python-level attribute check comes first, before the literal is touched.
    assert sd.file_format == true_file_format, (
        f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}."
    )

    # Then check the format recorded in the literal's type metadata.
    literal_sd = sd._literal_sd
    sd_type = literal_sd.metadata.structured_dataset_type
    assert sd_type.format == true_file_format, (
        f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}."
    )

    # Debug output; the dataframe is loaded last, right before it is printed.
    print(f">>> SD <<<\n{sd}")
    print(f">>> Literal SD <<<\n{literal_sd}")
    print(f">>> SDT <<<\n{sd_type}")
    loaded_df = sd.open(pd.DataFrame).all()
    print(f">>> DF <<<\n{loaded_df}")

    return sd


@workflow
def wf(dc: DC, file_format: str) -> StructuredDataset:
    """Run ``check_file_format`` on the StructuredDataset held by ``dc``.

    Args:
        dc: Dataclass whose ``sd`` attribute is the dataset to check.
        file_format: Expected file_format, forwarded as ``true_file_format``.

    Returns:
        The StructuredDataset returned by ``check_file_format``.
    """
    # The expected format is passed in as a separate workflow input: using
    # dc.sd.file_format as the task input fails here.
    return check_file_format(sd=dc.sd, true_file_format=file_format)


if __name__ == "__main__":
    # Local run: build the dataclass from a parquet file in the repo, pass it
    # through the workflow, and print the resulting file_format.
    local_inputs = {
        "uri": "tests/flytekit/integration/remote/workflows/basic/data/df.parquet",
        "file_format": "parquet",
    }

    wrapped = create_dc(**local_inputs)
    checked_sd = wf(dc=wrapped, file_format=local_inputs["file_format"])
    print(checked_sd.file_format)
Loading