Skip to content

Commit c5ef3bc

Browse files
Attempt to improve liskov substitution principle error (Python-Cardano#128)
* UPDATE. including network.py module in mypy test * UPDATE: bump mypy version UPDATE: include serialization.py in mypy test UPDATE: Primitive is a union type now UPDATE: f-string formatted values are explicitly converted into strings REMOVE: field_sorter() had no real usage UPDATE: from_primitive() takes a generalized Primitive value UPDATE: narrow Primitive value type within from_primitive() FIX: BlockFrostChainContext constructor to have an empty string default value * ADD. adding network.py module unittests for better coverage
1 parent 3d835c0 commit c5ef3bc

File tree

6 files changed

+161
-94
lines changed

6 files changed

+161
-94
lines changed

poetry.lock

Lines changed: 79 additions & 72 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pycardano/backend/blockfrost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class BlockFrostChainContext(ChainContext):
4949
_protocol_param: Optional[ProtocolParameters] = None
5050

5151
def __init__(
52-
self, project_id: str, network: Network = Network.TESTNET, base_url: str = None
52+
self, project_id: str, network: Network = Network.TESTNET, base_url: str = ""
5353
):
5454
self._network = network
5555
self._project_id = project_id

pycardano/network.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from enum import Enum
66
from typing import Type
77

8-
from pycardano.serialization import CBORSerializable
8+
from pycardano.exception import DeserializeException
9+
from pycardano.serialization import CBORSerializable, Primitive
910

1011
__all__ = ["Network"]
1112

@@ -22,5 +23,9 @@ def to_primitive(self) -> int:
2223
return self.value
2324

2425
@classmethod
25-
def from_primitive(cls: Type[Network], value: int) -> Network:
26+
def from_primitive(cls: Type[Network], value: Primitive) -> Network:
27+
if not isinstance(value, int):
28+
raise DeserializeException(
29+
f"An integer value is required for deserialization: {str(value)}"
30+
)
2631
return cls(value)

pycardano/serialization.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from datetime import datetime
1010
from decimal import Decimal
1111
from inspect import isclass
12-
from typing import Any, Callable, ClassVar, List, Type, TypeVar, Union, get_type_hints
12+
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
1313

1414
from cbor2 import CBOREncoder, CBORSimpleValue, CBORTag, dumps, loads, undefined
1515
from pprintpp import pformat
@@ -53,8 +53,31 @@ class RawCBOR:
5353
cbor: bytes
5454

5555

56-
Primitive = TypeVar(
57-
"Primitive",
56+
Primitive = Union[
57+
bytes,
58+
bytearray,
59+
str,
60+
int,
61+
float,
62+
Decimal,
63+
bool,
64+
None,
65+
tuple,
66+
list,
67+
IndefiniteList,
68+
dict,
69+
defaultdict,
70+
OrderedDict,
71+
undefined.__class__,
72+
datetime,
73+
re.Pattern,
74+
CBORSimpleValue,
75+
CBORTag,
76+
set,
77+
frozenset,
78+
]
79+
80+
PRIMITIVE_TYPES = (
5881
bytes,
5982
bytearray,
6083
str,
@@ -381,10 +404,10 @@ def _restore_dataclass_field(
381404
return t.from_primitive(v)
382405
except DeserializeException:
383406
pass
384-
elif t in Primitive.__constraints__ and isinstance(v, t):
407+
elif t in PRIMITIVE_TYPES and isinstance(v, t):
385408
return v
386409
raise DeserializeException(
387-
f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
410+
f"Cannot deserialize object: \n{str(v)}\n in any valid type from {t_args}."
388411
)
389412
return v
390413

@@ -453,8 +476,6 @@ class ArrayCBORSerializable(CBORSerializable):
453476
Test2(c='c', test1=Test1(a='a', b=None))
454477
"""
455478

456-
field_sorter: ClassVar[Callable[[List], List]] = lambda x: x
457-
458479
def to_shallow_primitive(self) -> List[Primitive]:
459480
"""
460481
Returns:
@@ -465,15 +486,15 @@ def to_shallow_primitive(self) -> List[Primitive]:
465486
types.
466487
"""
467488
primitives = []
468-
for f in self.__class__.field_sorter(fields(self)):
489+
for f in fields(self):
469490
val = getattr(self, f.name)
470491
if val is None and f.metadata.get("optional"):
471492
continue
472493
primitives.append(val)
473494
return primitives
474495

475496
@classmethod
476-
def from_primitive(cls: Type[ArrayBase], values: List[Primitive]) -> ArrayBase:
497+
def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
477498
"""Restore a primitive value to its original class type.
478499
479500
Args:
@@ -660,7 +681,7 @@ def __init__(self, *args, **kwargs):
660681
def __getattr__(self, item):
661682
return getattr(self.data, item)
662683

663-
def __setitem__(self, key: KEY_TYPE, value: VALUE_TYPE):
684+
def __setitem__(self, key: Any, value: Any):
664685
check_type("key", key, self.KEY_TYPE)
665686
check_type("value", value, self.VALUE_TYPE)
666687
self.data[key] = value
@@ -704,7 +725,7 @@ def _get_sortable_val(key):
704725
return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0])))
705726

706727
@classmethod
707-
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
728+
def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
708729
"""Restore a primitive value to its original class type.
709730
710731
Args:
@@ -718,13 +739,17 @@ def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
718739
DeserializeException: When the object could not be restored from primitives.
719740
"""
720741
if not value:
721-
raise DeserializeException(f"Cannot accept empty value {value}.")
742+
raise DeserializeException(f"Cannot accept empty value {str(value)}.")
743+
if not isinstance(value, dict):
744+
raise DeserializeException(
745+
f"A dictionary value is required for deserialization: {str(value)}"
746+
)
747+
722748
restored = cls()
723749
for k, v in value.items():
724750
k = (
725751
cls.KEY_TYPE.from_primitive(k)
726-
if isclass(cls.VALUE_TYPE)
727-
and issubclass(cls.KEY_TYPE, CBORSerializable)
752+
if isclass(cls.KEY_TYPE) and issubclass(cls.KEY_TYPE, CBORSerializable)
728753
else k
729754
)
730755
v = (
@@ -736,13 +761,13 @@ def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
736761
restored[k] = v
737762
return restored
738763

739-
def copy(self) -> DictBase:
764+
def copy(self) -> DictCBORSerializable:
740765
return self.__class__(self)
741766

742767

743768
@typechecked
744769
def list_hook(
745-
cls: Type[CBORSerializable],
770+
cls: Type[CBORBase],
746771
) -> Callable[[List[Primitive]], List[CBORBase]]:
747772
"""A factory that generates a Callable which turns a list of Primitive to a list of CBORSerializables.
748773

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ sphinx-copybutton = "^0.5.0"
4444
retry = "^0.9.2"
4545
Flask = "^2.0.3"
4646
pytest-xdist = "^3.0.2"
47-
mypy = "^0.982"
47+
mypy = "^0.990"
4848

4949
[build-system]
5050
requires = ["poetry-core>=1.0.0"]
@@ -76,9 +76,7 @@ exclude = [
7676
'^pycardano/logging.py$',
7777
'^pycardano/metadata.py$',
7878
'^pycardano/nativescript.py$',
79-
'^pycardano/network.py$',
8079
'^pycardano/plutus.py$',
81-
'^pycardano/serialization.py$',
8280
'^pycardano/transaction.py$',
8381
'^pycardano/txbuilder.py$',
8482
'^pycardano/utils.py$',

test/pycardano/test_network.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from pycardano.exception import DeserializeException
4+
from pycardano.network import Network
5+
6+
7+
def test_from_primitive_invalid_primitive_input():
8+
value = "a string value"
9+
with pytest.raises(DeserializeException):
10+
Network.from_primitive(value)
11+
12+
13+
def test_from_primitive_testnet():
14+
testnet_value = 0
15+
network = Network.from_primitive(testnet_value)
16+
assert network.value == testnet_value
17+
18+
19+
def test_from_primitive_mainnet():
20+
mainnet_value = 1
21+
network = Network.from_primitive(mainnet_value)
22+
assert network.value == mainnet_value
23+
24+
25+
def test_to_primitive_testnet():
26+
network = Network(0)
27+
assert network.to_primitive() == 0
28+
29+
30+
def test_to_primitive_mainnet():
31+
network = Network(1)
32+
assert network.to_primitive() == 1

0 commit comments

Comments
 (0)