Skip to content

Commit f115157

Browse files
Support mashumaro DataClassORJSONMixin (#2080)
* support DataClassJSONMixin

  Signed-off-by: Quinten Roets <quinten.roets@flawlessai.com>

* make union syntax compatible with python3.8

  Signed-off-by: Quinten Roets <quinten.roets@flawlessai.com>

* add datetime attribute

  Signed-off-by: Quinten Roets <quinten.roets@flawlessai.com>

* centralize serializable checking

  Signed-off-by: Quinten Roets <quinten.roets@flawlessai.com>

* Incorporate feedback

  Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>

---------

Signed-off-by: Quinten Roets <quinten.roets@flawlessai.com>
Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
1 parent 1c8d4bd commit f115157

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

dev-requirements.in

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,3 +45,5 @@ pandas
4545
scikit-learn
4646
types-requests
4747
prometheus-client
48+
49+
orjson

flytekit/core/type_engine.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,13 @@ class Test(DataClassJsonMixin):
325325

326326
def __init__(self):
327327
super().__init__("Object-Dataclass-Transformer", object)
328+
self._serializable_classes = [DataClassJSONMixin, DataClassJsonMixin]
329+
try:
330+
from mashumaro.mixins.orjson import DataClassORJSONMixin
331+
332+
self._serializable_classes.append(DataClassORJSONMixin)
333+
except ModuleNotFoundError:
334+
pass
328335

329336
def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
330337
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
@@ -417,7 +424,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
417424
f"Type {t} cannot be parsed."
418425
)
419426

420-
if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin):
427+
if not self.is_serializable_class(t):
421428
raise AssertionError(
422429
f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
423430
f"serialized correctly"
@@ -465,6 +472,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
465472

466473
return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts)
467474

475+
def is_serializable_class(self, class_: Type[T]) -> bool:
476+
return any(issubclass(class_, serializable_class) for serializable_class in self._serializable_classes)
477+
468478
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
469479
if isinstance(python_val, dict):
470480
json_str = json.dumps(python_val)
@@ -475,9 +485,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
475485
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
476486
f"user defined datatypes in Flytekit"
477487
)
478-
if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(
479-
type(python_val), DataClassJSONMixin
480-
):
488+
if not self.is_serializable_class(type(python_val)):
481489
raise TypeTransformerFailedError(
482490
f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be "
483491
f"serialized correctly"
@@ -730,9 +738,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
730738
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
731739
"user defined datatypes in Flytekit"
732740
)
733-
if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass(
734-
expected_python_type, DataClassJSONMixin
735-
):
741+
if not self.is_serializable_class(expected_python_type):
736742
raise TypeTransformerFailedError(
737743
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
738744
f"serialized correctly"

tests/flytekit/unit/core/test_type_engine.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from marshmallow_enum import LoadDumpOptions
2323
from marshmallow_jsonschema import JSONSchema
2424
from mashumaro.mixins.json import DataClassJSONMixin
25+
from mashumaro.mixins.orjson import DataClassORJSONMixin
2526
from typing_extensions import Annotated, get_args, get_origin
2627

2728
from flytekit import kwtypes
@@ -2366,6 +2367,10 @@ def test_DataclassTransformer_get_literal_type():
23662367
class MyDataClassMashumaro(DataClassJsonMixin):
23672368
x: int
23682369

2370+
@dataclass
2371+
class MyDataClassMashumaroORJSON(DataClassJsonMixin):
2372+
x: int
2373+
23692374
@dataclass_json
23702375
@dataclass
23712376
class MyDataClass:
@@ -2379,6 +2384,9 @@ class MyDataClass:
23792384
literal_type = de.get_literal_type(MyDataClassMashumaro)
23802385
assert literal_type is not None
23812386

2387+
literal_type = de.get_literal_type(MyDataClassMashumaroORJSON)
2388+
assert literal_type is not None
2389+
23822390
invalid_json_str = "{ unbalanced_braces"
23832391
with pytest.raises(Exception):
23842392
Literal(scalar=Scalar(generic=_json_format.Parse(invalid_json_str, _struct.Struct())))
@@ -2389,6 +2397,10 @@ def test_DataclassTransformer_to_literal():
23892397
class MyDataClassMashumaro(DataClassJsonMixin):
23902398
x: int
23912399

2400+
@dataclass
2401+
class MyDataClassMashumaroORJSON(DataClassORJSONMixin):
2402+
x: int
2403+
23922404
@dataclass_json
23932405
@dataclass
23942406
class MyDataClass:
@@ -2398,12 +2410,19 @@ class MyDataClass:
23982410
ctx = FlyteContext.current_context()
23992411

24002412
my_dat_class_mashumaro = MyDataClassMashumaro(5)
2413+
my_dat_class_mashumaro_orjson = MyDataClassMashumaroORJSON(5)
24012414
my_data_class = MyDataClass(5)
24022415

24032416
lv_mashumaro = transformer.to_literal(ctx, my_dat_class_mashumaro, MyDataClassMashumaro, MyDataClassMashumaro)
24042417
assert lv_mashumaro is not None
24052418
assert lv_mashumaro.scalar.generic["x"] == 5
24062419

2420+
lv_mashumaro_orjson = transformer.to_literal(
2421+
ctx, my_dat_class_mashumaro_orjson, MyDataClassMashumaroORJSON, MyDataClassMashumaroORJSON
2422+
)
2423+
assert lv_mashumaro_orjson is not None
2424+
assert lv_mashumaro_orjson.scalar.generic["x"] == 5
2425+
24072426
lv = transformer.to_literal(ctx, my_data_class, MyDataClass, MyDataClass)
24082427
assert lv is not None
24092428
assert lv.scalar.generic["x"] == 5
@@ -2414,6 +2433,10 @@ def test_DataclassTransformer_to_python_value():
24142433
class MyDataClassMashumaro(DataClassJsonMixin):
24152434
x: int
24162435

2436+
@dataclass
2437+
class MyDataClassMashumaroORJSON(DataClassORJSONMixin):
2438+
x: int
2439+
24172440
@dataclass_json
24182441
@dataclass
24192442
class MyDataClass:
@@ -2432,8 +2455,18 @@ class MyDataClass:
24322455
assert isinstance(result, MyDataClassMashumaro)
24332456
assert result.x == 5
24342457

2458+
result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClassMashumaroORJSON)
2459+
assert isinstance(result, MyDataClassMashumaroORJSON)
2460+
assert result.x == 5
2461+
24352462

24362463
def test_DataclassTransformer_guess_python_type():
2464+
@dataclass
2465+
class DatumMashumaroORJSON(DataClassORJSONMixin):
2466+
x: int
2467+
y: Color
2468+
z: datetime.datetime
2469+
24372470
@dataclass
24382471
class DatumMashumaro(DataClassJSONMixin):
24392472
x: int
@@ -2464,6 +2497,16 @@ class Datum(DataClassJSONMixin):
24642497
assert datum_mashumaro.x == pv.x
24652498
assert datum_mashumaro.y.value == pv.y
24662499

2500+
lt = TypeEngine.to_literal_type(DatumMashumaroORJSON)
2501+
now = datetime.datetime.now()
2502+
datum_mashumaro_orjson = DatumMashumaroORJSON(5, Color.RED, now)
2503+
lv = transformer.to_literal(ctx, datum_mashumaro_orjson, DatumMashumaroORJSON, lt)
2504+
gt = transformer.guess_python_type(lt)
2505+
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
2506+
assert datum_mashumaro_orjson.x == pv.x
2507+
assert datum_mashumaro_orjson.y.value == pv.y
2508+
assert datum_mashumaro_orjson.z.isoformat() == pv.z
2509+
24672510

24682511
def test_ListTransformer_get_sub_type():
24692512
assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str

0 commit comments

Comments
 (0)