Skip to content

Commit cf41ba1

Browse files
authored
Add support for complex dictionary types (Python-Cardano#180)
* Add support for complex dictionary types * Add a cbor test * Fix parsing CBOR of dicts * Hide identity
1 parent 8f29313 commit cf41ba1

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

pycardano/plutus.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,35 @@ def _dfs(obj):
589589
raise DeserializeException(
590590
f"Unexpected data structure: {f}."
591591
)
592+
elif (
593+
hasattr(f_info.type, "__origin__")
594+
and f_info.type.__origin__ is dict
595+
):
596+
t_args = f_info.type.__args__
597+
if len(t_args) != 2:
598+
raise DeserializeException(
599+
"Dict type with wrong number of arguments"
600+
)
601+
if "map" not in f:
602+
raise DeserializeException(
603+
f'Expected type "map" in object but got "{f}"'
604+
)
605+
key_t = t_args[0]
606+
val_t = t_args[1]
607+
if inspect.isclass(key_t) and issubclass(key_t, PlutusData):
608+
key_convert = key_t.from_dict
609+
else:
610+
key_convert = _dfs
611+
if inspect.isclass(val_t) and issubclass(val_t, PlutusData):
612+
val_convert = val_t.from_dict
613+
else:
614+
val_convert = _dfs
615+
converted_fields.append(
616+
{
617+
key_convert(pair["k"]): val_convert(pair["v"])
618+
for pair in f["map"]
619+
}
620+
)
592621
else:
593622
converted_fields.append(_dfs(f))
594623
return cls(*converted_fields)

pycardano/serialization.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
]
3838

3939

40+
def _identity(x):
41+
return x
42+
43+
4044
class IndefiniteList(UserList):
4145
def __init__(self, li: Primitive): # type: ignore
4246
super().__init__(li) # type: ignore
@@ -415,6 +419,25 @@ def _restore_dataclass_field(
415419
return f.type.from_primitive(v)
416420
elif isclass(f.type) and issubclass(f.type, IndefiniteList):
417421
return IndefiniteList(v)
422+
elif hasattr(f.type, "__origin__") and (f.type.__origin__ is dict):
423+
t_args = f.type.__args__
424+
if len(t_args) != 2:
425+
raise DeserializeException(
426+
f"Dict types need exactly two type arguments, but got {t_args}"
427+
)
428+
key_t = t_args[0]
429+
val_t = t_args[1]
430+
if isclass(key_t) and issubclass(key_t, CBORSerializable):
431+
key_converter = key_t.from_primitive
432+
else:
433+
key_converter = _identity
434+
if isclass(val_t) and issubclass(val_t, CBORSerializable):
435+
val_converter = val_t.from_primitive
436+
else:
437+
val_converter = _identity
438+
if not isinstance(v, dict):
439+
raise DeserializeException(f"Expected dict type but got {type(v)}")
440+
return {key_converter(key): val_converter(val) for key, val in v.items()}
418441
elif hasattr(f.type, "__origin__") and (
419442
f.type.__origin__ is Union or f.type.__origin__ is Optional
420443
):

test/pycardano/test_plutus.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass
2+
import unittest
3+
24
from test.pycardano.util import check_two_way_cbor
3-
from typing import Union, Optional
5+
from typing import Union, Dict
46

57
import pytest
68

@@ -39,6 +41,13 @@ class LargestTest(PlutusData):
3941
CONSTR_ID = 9
4042

4143

44+
@dataclass
45+
class DictTest(PlutusData):
46+
CONSTR_ID = 3
47+
48+
a: Dict[int, LargestTest]
49+
50+
4251
@dataclass
4352
class VestingParam(PlutusData):
4453
CONSTR_ID = 1
@@ -95,6 +104,29 @@ def test_plutus_data_json():
95104
assert my_vesting == VestingParam.from_json(encoded_json)
96105

97106

107+
def test_plutus_data_json_dict():
108+
test = DictTest({0: LargestTest(), 1: LargestTest()})
109+
110+
encoded_json = test.to_json(separators=(",", ":"))
111+
112+
assert (
113+
'{"constructor":3,"fields":[{"map":[{"v":{"constructor":9,"fields":[]},"k":{"int":0}},{"v":{"constructor":9,"fields":[]},"k":{"int":1}}]}]}'
114+
== encoded_json
115+
)
116+
117+
assert test == DictTest.from_json(encoded_json)
118+
119+
120+
def test_plutus_data_cbor_dict():
121+
test = DictTest({0: LargestTest(), 1: LargestTest()})
122+
123+
encoded_cbor = test.to_cbor()
124+
125+
assert "d87c9fa200d905028001d9050280ff" == encoded_cbor
126+
127+
assert test == DictTest.from_cbor(encoded_cbor)
128+
129+
98130
def test_plutus_data_to_json_wrong_type():
99131
test = MyTest(123, b"1234", IndefiniteList([4, 5, 6]), {1: b"1", 2: b"2"})
100132
test.a = "123"

0 commit comments

Comments
 (0)