9
9
from datetime import datetime
10
10
from decimal import Decimal
11
11
from inspect import isclass
12
- from typing import Any , Callable , ClassVar , List , Type , TypeVar , Union , get_type_hints
12
+ from typing import Any , Callable , List , Type , TypeVar , Union , get_type_hints
13
13
14
14
from cbor2 import CBOREncoder , CBORSimpleValue , CBORTag , dumps , loads , undefined
15
15
from pprintpp import pformat
@@ -53,8 +53,31 @@ class RawCBOR:
53
53
cbor : bytes
54
54
55
55
56
- Primitive = TypeVar (
57
- "Primitive" ,
56
+ Primitive = Union [
57
+ bytes ,
58
+ bytearray ,
59
+ str ,
60
+ int ,
61
+ float ,
62
+ Decimal ,
63
+ bool ,
64
+ None ,
65
+ tuple ,
66
+ list ,
67
+ IndefiniteList ,
68
+ dict ,
69
+ defaultdict ,
70
+ OrderedDict ,
71
+ undefined .__class__ ,
72
+ datetime ,
73
+ re .Pattern ,
74
+ CBORSimpleValue ,
75
+ CBORTag ,
76
+ set ,
77
+ frozenset ,
78
+ ]
79
+
80
+ PRIMITIVE_TYPES = (
58
81
bytes ,
59
82
bytearray ,
60
83
str ,
@@ -381,10 +404,10 @@ def _restore_dataclass_field(
381
404
return t .from_primitive (v )
382
405
except DeserializeException :
383
406
pass
384
- elif t in Primitive . __constraints__ and isinstance (v , t ):
407
+ elif t in PRIMITIVE_TYPES and isinstance (v , t ):
385
408
return v
386
409
raise DeserializeException (
387
- f"Cannot deserialize object: \n { v } \n in any valid type from { t_args } ."
410
+ f"Cannot deserialize object: \n { str ( v ) } \n in any valid type from { t_args } ."
388
411
)
389
412
return v
390
413
@@ -453,8 +476,6 @@ class ArrayCBORSerializable(CBORSerializable):
453
476
Test2(c='c', test1=Test1(a='a', b=None))
454
477
"""
455
478
456
- field_sorter : ClassVar [Callable [[List ], List ]] = lambda x : x
457
-
458
479
def to_shallow_primitive (self ) -> List [Primitive ]:
459
480
"""
460
481
Returns:
@@ -465,15 +486,15 @@ def to_shallow_primitive(self) -> List[Primitive]:
465
486
types.
466
487
"""
467
488
primitives = []
468
- for f in self . __class__ . field_sorter ( fields (self ) ):
489
+ for f in fields (self ):
469
490
val = getattr (self , f .name )
470
491
if val is None and f .metadata .get ("optional" ):
471
492
continue
472
493
primitives .append (val )
473
494
return primitives
474
495
475
496
@classmethod
476
- def from_primitive (cls : Type [ArrayBase ], values : List [ Primitive ] ) -> ArrayBase :
497
+ def from_primitive (cls : Type [ArrayBase ], values : Primitive ) -> ArrayBase :
477
498
"""Restore a primitive value to its original class type.
478
499
479
500
Args:
@@ -660,7 +681,7 @@ def __init__(self, *args, **kwargs):
660
681
def __getattr__ (self , item ):
661
682
return getattr (self .data , item )
662
683
663
- def __setitem__ (self , key : KEY_TYPE , value : VALUE_TYPE ):
684
+ def __setitem__ (self , key : Any , value : Any ):
664
685
check_type ("key" , key , self .KEY_TYPE )
665
686
check_type ("value" , value , self .VALUE_TYPE )
666
687
self .data [key ] = value
@@ -704,7 +725,7 @@ def _get_sortable_val(key):
704
725
return dict (sorted (self .data .items (), key = lambda x : _get_sortable_val (x [0 ])))
705
726
706
727
@classmethod
707
- def from_primitive (cls : Type [DictBase ], value : dict ) -> DictBase :
728
+ def from_primitive (cls : Type [DictBase ], value : Primitive ) -> DictBase :
708
729
"""Restore a primitive value to its original class type.
709
730
710
731
Args:
@@ -718,13 +739,17 @@ def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
718
739
DeserializeException: When the object could not be restored from primitives.
719
740
"""
720
741
if not value :
721
- raise DeserializeException (f"Cannot accept empty value { value } ." )
742
+ raise DeserializeException (f"Cannot accept empty value { str (value )} ." )
743
+ if not isinstance (value , dict ):
744
+ raise DeserializeException (
745
+ f"A dictionary value is required for deserialization: { str (value )} "
746
+ )
747
+
722
748
restored = cls ()
723
749
for k , v in value .items ():
724
750
k = (
725
751
cls .KEY_TYPE .from_primitive (k )
726
- if isclass (cls .VALUE_TYPE )
727
- and issubclass (cls .KEY_TYPE , CBORSerializable )
752
+ if isclass (cls .KEY_TYPE ) and issubclass (cls .KEY_TYPE , CBORSerializable )
728
753
else k
729
754
)
730
755
v = (
@@ -736,13 +761,13 @@ def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
736
761
restored [k ] = v
737
762
return restored
738
763
739
- def copy (self ) -> DictBase :
764
+ def copy (self ) -> DictCBORSerializable :
740
765
return self .__class__ (self )
741
766
742
767
743
768
@typechecked
744
769
def list_hook (
745
- cls : Type [CBORSerializable ],
770
+ cls : Type [CBORBase ],
746
771
) -> Callable [[List [Primitive ]], List [CBORBase ]]:
747
772
"""A factory that generates a Callable which turns a list of Primitive to a list of CBORSerializables.
748
773
0 commit comments