From 2dcbb90f7a331eb2d75a826053fc41f944ad4aa2 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:43:06 -0700 Subject: [PATCH] Read offloaded literals (#2685) * [WIP] - Read offloaded literals Signed-off-by: Eduardo Apolinario * Use LiteralOffloadedMetadata field Signed-off-by: Eduardo Apolinario * Assert use of offloaded uri to get around typing constraint Signed-off-by: Eduardo Apolinario * Add a bunch of unit tests Signed-off-by: Eduardo Apolinario * Remove TODO and fix comment Signed-off-by: Eduardo Apolinario * Simplify generation of local file to store literal Signed-off-by: Eduardo Apolinario * Rename variable: `local_literal_file` to `literal_local_file` Signed-off-by: Eduardo Apolinario * Fix lint errors Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/core/type_engine.py | 10 +- flytekit/models/literals.py | 63 +++++- pyproject.toml | 2 +- .../unit/core/test_offloaded_literals.py | 179 ++++++++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 138 ++++++++++++++ 5 files changed, 388 insertions(+), 4 deletions(-) create mode 100644 tests/flytekit/unit/core/test_offloaded_literals.py diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d42e2c2a54..861909eedd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -33,7 +33,7 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag -from flytekit.core.utils import timeit +from flytekit.core.utils import load_proto_from_file, timeit from flytekit.exceptions import user as user_exceptions from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.lazy_import.lazy_module import is_imported @@ -1155,6 +1155,14 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ Converts a Literal value with an expected python type into a python value. """ + # Initiate the process of loading the offloaded literal if offloaded_metadata is set + if lv.offloaded_metadata: + literal_local_file = ctx.file_access.get_random_local_path() + assert lv.offloaded_metadata.uri, "missing offloaded uri" + ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file) + input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file) + lv = Literal.from_flyte_idl(input_proto) + transformer = cls.get_transformer(expected_python_type) return transformer.to_python_value(ctx, lv, expected_python_type) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 7d6ff76a89..9e14a95ce4 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -8,7 +8,7 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from flytekit.models.types import Error, StructuredDatasetType +from flytekit.models.types import Error, LiteralType, StructuredDatasetType from flytekit.models.types import LiteralType as _LiteralType from flytekit.models.types import OutputReference as _OutputReference from flytekit.models.types import SchemaType as _SchemaType @@ -852,6 +852,52 @@ def from_flyte_idl(cls, pb2_object): ) +class LiteralOffloadedMetadata(_common.FlyteIdlEntity): + def __init__( + self, + uri: Optional[str] = None, + size_bytes: Optional[int] = None, + inferred_type: Optional[LiteralType] = None, + ): + """ + :param Text uri: URI of the offloaded literal + :param int size_bytes: Size in bytes of the offloaded literal proto + :param LiteralType inferred_type: Inferred type of the offloaded literal + """ + self._uri = uri + self._size_bytes = size_bytes + self._inferred_type = inferred_type + + @property + def uri(self): + return self._uri + + @property + def size_bytes(self): + return self._size_bytes + + @property + def inferred_type(self): + return self._inferred_type + + def to_flyte_idl(self): + return _literals_pb2.LiteralOffloadedMetadata( + uri=self.uri, + size_bytes=self.size_bytes, + inferred_type=self.inferred_type.to_flyte_idl() if self.inferred_type else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + uri=pb2_object.uri, + size_bytes=pb2_object.size_bytes, + inferred_type=_LiteralType.from_flyte_idl(pb2_object.inferred_type) + if pb2_object.HasField("inferred_type") + else None, + ) + + class Literal(_common.FlyteIdlEntity): def __init__( self, @@ -860,6 +906,7 @@ def __init__( map: Optional[LiteralMap] = None, hash: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, + offloaded_metadata: Optional[LiteralOffloadedMetadata] = None, ): """ This IDL message represents a literal value in the Flyte ecosystem. @@ -873,6 +920,7 @@ def __init__( self._map = map self._hash = hash self._metadata = metadata + self._offloaded_metadata = offloaded_metadata @property def scalar(self): @@ -925,6 +973,13 @@ def metadata(self) -> Optional[Dict[str, str]]: """ return self._metadata + @property + def offloaded_metadata(self) -> Optional[LiteralOffloadedMetadata]: + """ + This value holds metadata about the offloaded literal. + """ + return self._offloaded_metadata + def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Literal @@ -935,10 +990,11 @@ def to_flyte_idl(self): map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, metadata=self.metadata, + offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None, ) @classmethod - def from_flyte_idl(cls, pb2_object): + def from_flyte_idl(cls, pb2_object: _literals_pb2.Literal): """ :param flyteidl.core.literals_pb2.Literal pb2_object: :rtype: Literal @@ -953,6 +1009,9 @@ def from_flyte_idl(cls, pb2_object): map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, hash=pb2_object.hash if pb2_object.hash else None, metadata={k: v for k, v in pb2_object.metadata.items()} if pb2_object.metadata else None, + offloaded_metadata=LiteralOffloadedMetadata.from_flyte_idl(pb2_object.offloaded_metadata) + if pb2_object.HasField("offloaded_metadata") + else None, ) def set_metadata(self, metadata: Dict[str, str]): diff --git a/pyproject.toml b/pyproject.toml index 8e8fcef90f..ba2cc46e83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.1", + "flyteidl>=1.13.4", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_offloaded_literals.py b/tests/flytekit/unit/core/test_offloaded_literals.py new file mode 100644 index 0000000000..97fd6e97c1 --- /dev/null +++ b/tests/flytekit/unit/core/test_offloaded_literals.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass +import typing + +from mashumaro.mixins.json import DataClassJSONMixin +import pytest +from flytekit import task +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.models import literals as literal_models +from flytekit.core import context_manager +from flytekit.models.types import SimpleType +from flytekit.core.type_engine import TypeEngine + +@pytest.fixture +def flyte_ctx(): + return context_manager.FlyteContext.current_context() + + +def test_task_offloaded_literal_single_input(tmp_path): + @task + def t1(a: int) -> str: + return str(a) + + original_input_literal = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "a": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"].scalar.primitive.string_value == "3" + + +def test_task_offloaded_literal_multiple_input(tmp_path): + @task + def t1(a: int, b: int) -> int: + return a + b + + original_input_literal_a = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ) + original_input_literal_b = literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) + ) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto_a.pb", "wb") as f: + f.write(original_input_literal_a.to_flyte_idl().SerializeToString()) + with open(f"{tmp_path}/offloaded_proto_b.pb", "wb") as f: + f.write(original_input_literal_b.to_flyte_idl().SerializeToString()) + + offloaded_input_literal_a = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto_a.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + offloaded_input_literal_b = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto_b.pb", + inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "a": offloaded_input_literal_a, + "b": offloaded_input_literal_b, + } + ), + ) + assert output_lm.literals["o0"].scalar.primitive.integer == 7 + + +def test_task_offloaded_literal_single_dataclass(tmp_path, flyte_ctx): + @dataclass + class DC(DataClassJSONMixin): + x: int + y: str + z: typing.List[int] + + @task + def t1(dc: DC) -> DC: + return dc + + lt = TypeEngine.to_literal_type(DC) + original_input_literal = TypeEngine.to_literal(flyte_ctx, DC(x=3, y="hello", z=[1, 2, 3]), DC, lt) + + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "dc": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"] == original_input_literal + + +def test_task_offloaded_literal_list_int(tmp_path): + @task + def t1(xs: typing.List[int]) -> typing.List[str]: + return [str(a) for a in xs] + + original_input_literal = literal_models.Literal( + collection=literal_models.LiteralCollection( + literals=[ + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3)) + ), + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4)) + ), + ] + ) + ) + expected_output_literal = literal_models.Literal( + collection=literal_models.LiteralCollection( + literals=[ + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="3")) + ), + literal_models.Literal( + scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="4")) + ), + ] + ) + ) + + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(original_input_literal.to_flyte_idl().SerializeToString()) + + offloaded_input_literal = literal_models.Literal( + offloaded_metadata=literal_models.LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=literal_models.LiteralType(collection_type=SimpleType.INTEGER), + ) + ) + + ctx = context_manager.FlyteContextManager.current_context() + output_lm = t1.dispatch_execute( + ctx, + literal_models.LiteralMap( + literals={ + "xs": offloaded_input_literal, + } + ), + ) + assert output_lm.literals["o0"] == expected_output_literal diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 58bba44151..a8e4cd31a8 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -59,6 +59,7 @@ Literal, LiteralCollection, LiteralMap, + LiteralOffloadedMetadata, Primitive, Scalar, Void, @@ -3204,6 +3205,143 @@ def test_union_file_directory(): assert pv._remote_source == s3_dir +@pytest.mark.parametrize( + "pt,pv", + [ + (bool, True), + (bool, False), + (int, 42), + (str, "hello"), + (Annotated[int, "tag"], 42), + (typing.List[int], [1, 2, 3]), + (typing.List[str], ["a", "b", "c"]), + (typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), + (typing.List[Annotated[int, "tag"]], [1, 2, 3]), + (typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), + (typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}), + (typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), + (typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + (typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}), + (typing.Union[int, str], 42), + (typing.Union[int, str], "hello"), + (typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), + (typing.Union[typing.List[int], typing.List[str]], ["a", "b", "c"]), + (typing.Union[typing.List[int], str], [1, 2, 3]), + (typing.Union[typing.List[int], str], "hello"), + ], +) +def test_offloaded_literal(tmp_path, pt, pv): + ctx = FlyteContext.current_context() + + lt = TypeEngine.to_literal_type(pt) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, pv, pt, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, pt) + assert loaded_pv == pv + + +def test_offloaded_literal_with_inferred_type(): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(str) + offloaded_literal_missing_uri = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + inferred_type=lt, + ), + ) + with pytest.raises(AssertionError): + TypeEngine.to_python_value(ctx, offloaded_literal_missing_uri, str) + + +def test_offloaded_literal_dataclass(tmp_path): + @dataclass + class InnerDatum(DataClassJsonMixin): + x: int + y: str + + @dataclass + class Datum(DataClassJsonMixin): + inner: InnerDatum + x: int + y: str + z: typing.Dict[int, int] + w: List[int] + + datum = Datum( + inner=InnerDatum(x=1, y="1"), + x=1, + y="1", + z={1: 1}, + w=[1, 1, 1, 1], + ) + + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(Datum) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, datum, Datum, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_datum = TypeEngine.to_python_value(ctx, literal, Datum) + assert loaded_datum == datum + + +def test_offloaded_literal_flytefile(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteFile) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-file", FlyteFile, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv = TypeEngine.to_python_value(ctx, literal, FlyteFile) + assert loaded_pv._remote_source == "s3://my-file" + + +def test_offloaded_literal_flytedirectory(tmp_path): + ctx = FlyteContext.current_context() + lt = TypeEngine.to_literal_type(FlyteDirectory) + to_be_offloaded_lv = TypeEngine.to_literal(ctx, "s3://my-dir", FlyteDirectory, lt) + + # Write offloaded_lv as bytes to a temp file + with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f: + f.write(to_be_offloaded_lv.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/offloaded_proto.pb", + inferred_type=lt, + ), + ) + + loaded_pv: FlyteDirectory = TypeEngine.to_python_value(ctx, literal, FlyteDirectory) + assert loaded_pv._remote_source == "s3://my-dir" @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_dataclass_none_output_input_deserialization(): @dataclass