22
22
from marshmallow_enum import LoadDumpOptions
23
23
from marshmallow_jsonschema import JSONSchema
24
24
from mashumaro .mixins .json import DataClassJSONMixin
25
+ from mashumaro .mixins .orjson import DataClassORJSONMixin
25
26
from typing_extensions import Annotated , get_args , get_origin
26
27
27
28
from flytekit import kwtypes
@@ -2366,6 +2367,10 @@ def test_DataclassTransformer_get_literal_type():
2366
2367
class MyDataClassMashumaro (DataClassJsonMixin ):
2367
2368
x : int
2368
2369
2370
+ @dataclass
2371
+ class MyDataClassMashumaroORJSON (DataClassJsonMixin ):
2372
+ x : int
2373
+
2369
2374
@dataclass_json
2370
2375
@dataclass
2371
2376
class MyDataClass :
@@ -2379,6 +2384,9 @@ class MyDataClass:
2379
2384
literal_type = de .get_literal_type (MyDataClassMashumaro )
2380
2385
assert literal_type is not None
2381
2386
2387
+ literal_type = de .get_literal_type (MyDataClassMashumaroORJSON )
2388
+ assert literal_type is not None
2389
+
2382
2390
invalid_json_str = "{ unbalanced_braces"
2383
2391
with pytest .raises (Exception ):
2384
2392
Literal (scalar = Scalar (generic = _json_format .Parse (invalid_json_str , _struct .Struct ())))
@@ -2389,6 +2397,10 @@ def test_DataclassTransformer_to_literal():
2389
2397
class MyDataClassMashumaro (DataClassJsonMixin ):
2390
2398
x : int
2391
2399
2400
+ @dataclass
2401
+ class MyDataClassMashumaroORJSON (DataClassORJSONMixin ):
2402
+ x : int
2403
+
2392
2404
@dataclass_json
2393
2405
@dataclass
2394
2406
class MyDataClass :
@@ -2398,12 +2410,19 @@ class MyDataClass:
2398
2410
ctx = FlyteContext .current_context ()
2399
2411
2400
2412
my_dat_class_mashumaro = MyDataClassMashumaro (5 )
2413
+ my_dat_class_mashumaro_orjson = MyDataClassMashumaroORJSON (5 )
2401
2414
my_data_class = MyDataClass (5 )
2402
2415
2403
2416
lv_mashumaro = transformer .to_literal (ctx , my_dat_class_mashumaro , MyDataClassMashumaro , MyDataClassMashumaro )
2404
2417
assert lv_mashumaro is not None
2405
2418
assert lv_mashumaro .scalar .generic ["x" ] == 5
2406
2419
2420
+ lv_mashumaro_orjson = transformer .to_literal (
2421
+ ctx , my_dat_class_mashumaro_orjson , MyDataClassMashumaroORJSON , MyDataClassMashumaroORJSON
2422
+ )
2423
+ assert lv_mashumaro_orjson is not None
2424
+ assert lv_mashumaro_orjson .scalar .generic ["x" ] == 5
2425
+
2407
2426
lv = transformer .to_literal (ctx , my_data_class , MyDataClass , MyDataClass )
2408
2427
assert lv is not None
2409
2428
assert lv .scalar .generic ["x" ] == 5
@@ -2414,6 +2433,10 @@ def test_DataclassTransformer_to_python_value():
2414
2433
class MyDataClassMashumaro (DataClassJsonMixin ):
2415
2434
x : int
2416
2435
2436
+ @dataclass
2437
+ class MyDataClassMashumaroORJSON (DataClassORJSONMixin ):
2438
+ x : int
2439
+
2417
2440
@dataclass_json
2418
2441
@dataclass
2419
2442
class MyDataClass :
@@ -2432,8 +2455,18 @@ class MyDataClass:
2432
2455
assert isinstance (result , MyDataClassMashumaro )
2433
2456
assert result .x == 5
2434
2457
2458
+ result = de .to_python_value (FlyteContext .current_context (), mock_literal , MyDataClassMashumaroORJSON )
2459
+ assert isinstance (result , MyDataClassMashumaroORJSON )
2460
+ assert result .x == 5
2461
+
2435
2462
2436
2463
def test_DataclassTransformer_guess_python_type ():
2464
+ @dataclass
2465
+ class DatumMashumaroORJSON (DataClassORJSONMixin ):
2466
+ x : int
2467
+ y : Color
2468
+ z : datetime .datetime
2469
+
2437
2470
@dataclass
2438
2471
class DatumMashumaro (DataClassJSONMixin ):
2439
2472
x : int
@@ -2464,6 +2497,16 @@ class Datum(DataClassJSONMixin):
2464
2497
assert datum_mashumaro .x == pv .x
2465
2498
assert datum_mashumaro .y .value == pv .y
2466
2499
2500
+ lt = TypeEngine .to_literal_type (DatumMashumaroORJSON )
2501
+ now = datetime .datetime .now ()
2502
+ datum_mashumaro_orjson = DatumMashumaroORJSON (5 , Color .RED , now )
2503
+ lv = transformer .to_literal (ctx , datum_mashumaro_orjson , DatumMashumaroORJSON , lt )
2504
+ gt = transformer .guess_python_type (lt )
2505
+ pv = transformer .to_python_value (ctx , lv , expected_python_type = gt )
2506
+ assert datum_mashumaro_orjson .x == pv .x
2507
+ assert datum_mashumaro_orjson .y .value == pv .y
2508
+ assert datum_mashumaro_orjson .z .isoformat () == pv .z
2509
+
2467
2510
2468
2511
def test_ListTransformer_get_sub_type ():
2469
2512
assert ListTransformer .get_sub_type_or_none (typing .List [str ]) is str
0 commit comments