diff --git a/hsms/clvm_serde/__init__.py b/hsms/clvm_serde/__init__.py index b85d9ca..cd4a121 100644 --- a/hsms/clvm_serde/__init__.py +++ b/hsms/clvm_serde/__init__.py @@ -1,5 +1,5 @@ from dataclasses import is_dataclass, fields, MISSING -from typing import Any, Callable, Optional, Tuple, Type, get_type_hints +from typing import Any, Callable, Optional, Tuple, Type, Union, get_type_hints from chia_base.meta.type_tree import ArgsType, CompoundLookup, OriginArgsType, TypeTree from chia_base.meta.typing import GenericAlias @@ -52,6 +52,21 @@ def serialize_list(items): return serialize_list +def serialize_for_optional(origin, args, type_tree: TypeTree) -> Program: + if len(args) == 2 and type(None) is args[1]: + write_item = type_tree(args[0]) + + def serialize_optional(item): + if item is None: + return Program.to((Program.null(), Program.null())) + else: + return Program.to((1, write_item(item))) + + return serialize_optional + else: + raise ValueError("No serialization support for Union types (besides Optional)") + + def serialize_for_tuple(origin, args, type_tree: TypeTree) -> Program: write_items = [type_tree(_) for _ in args] @@ -99,6 +114,7 @@ def ser(item): list: serialize_for_list, tuple: serialize_for_tuple, tuple_frugal: ser_for_tuple_frugal, + Union: serialize_for_optional, } @@ -306,11 +322,26 @@ def de(p: Program) -> Tuple[Any, ...]: return de +def deser_for_optional(origin, args, type_tree: TypeTree): + if len(args) == 2 and type(None) is args[1]: + read_item = type_tree(args[0]) + + def deserialize_optional(p: Program): + if p.first() == Program.null(): + return None + else: + return read_item(p.rest()) + + return deserialize_optional + else: + raise ValueError("No serialization support for Union types (besides Optional)") + DESERIALIZER_COMPOUND_TYPE_LOOKUP: CompoundLookup[FromProgram] = { list: deser_for_list, tuple: deser_for_tuple, tuple_frugal: de_for_tuple_frugal, + Union: deser_for_optional, } diff --git a/tests/test_clvm_serde.py b/tests/test_clvm_serde.py index 194a506..ea05159 100644 --- a/tests/test_clvm_serde.py +++ b/tests/test_clvm_serde.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Tuple +from typing import List, Optional, Union, Tuple import random @@ -134,6 +134,41 @@ class Foo: f1 = fp(p) assert f1 == foo + @dataclass + class Foo: + a: Optional[int] = field(metadata=dict(key="a")) + + tp = to_program_for_type(Foo) + fp = from_program_for_type(Foo) + p = Program.to([("a", Program.to((1, 1000)))]) + foo = fp(p) + assert foo.a == 1000 + p1 = tp(foo) + assert p1 == p + p = Program.to([("a", Program.to((Program.null(), Program.null())))]) + foo = fp(p) + assert foo.a is None + p1 = tp(foo) + assert p1 == p + + @dataclass + class Foo: + a: Optional[List[int]] = field(metadata=dict(key="a")) + + tp = to_program_for_type(Foo) + fp = from_program_for_type(Foo) + assert fp(tp(Foo([]))) == Foo([]) + + @dataclass + class Bar: + a: Union[int, str] + + with pytest.raises(ValueError): + _ = to_program_for_type(Bar) + + with pytest.raises(ValueError): + _ = from_program_for_type(Bar) + def test_serde_frugal(): @dataclass