Skip to content

Commit 4297af6

Browse files
authored
Fixed the order of marshal to handle Dataclass with as_dict before other types to avoid SerdeError (#60)
Situation: While doing the installation.save with a Dataclass object, it goes to the cls._marshal_dataclass(type_ref, path, inst) which results in SerdeError. Fix: Change the order of evaluation and bring the below code in the beginning if hasattr(inst, "as_dict"): return inst.as_dict(), True How it's tested: Unit tested with sample Dataclass and it works
1 parent 905e5ff commit 4297af6

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/databricks/labs/blueprint/installation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ def _marshal(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool
443443
_UnionGenericAlias,
444444
)
445445

446+
if hasattr(inst, "as_dict"):
447+
return inst.as_dict(), True
446448
if dataclasses.is_dataclass(type_ref):
447449
return cls._marshal_dataclass(type_ref, path, inst)
448450
if isinstance(type_ref, types.GenericAlias):
@@ -463,8 +465,7 @@ def _marshal(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool
463465
return cls._marshal_databricks_config(inst)
464466
if type_ref in cls._PRIMITIVES:
465467
return inst, True
466-
if hasattr(inst, "as_dict"):
467-
return inst.as_dict(), True
468+
468469
raise SerdeError(f'{".".join(path)}: unknown: {inst}')
469470

470471
@classmethod

tests/unit/test_installation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,21 @@ def test_as_dict_serde():
385385

386386
load = installation.load(SomePolicy, filename="backups/policy-123.json")
387387
assert load == policy
388+
389+
390+
@dataclass
391+
class Policy:
392+
policy_id: str
393+
name: str
394+
395+
def as_dict(self) -> dict:
396+
return {"policy_id": self.policy_id, "name": self.name}
397+
398+
399+
def test_data_class():
400+
installation = MockInstallation()
401+
policy = Policy("123", "foo")
402+
installation.save(policy, filename="backups/policy-test.json")
403+
installation.assert_file_written("backups/policy-test.json", {"policy_id": "123", "name": "foo"})
404+
load = installation.load(Policy, filename="backups/policy-test.json")
405+
assert load == policy

0 commit comments

Comments
 (0)