3
3
from __future__ import annotations
4
4
5
5
import inspect
6
+ import sys
6
7
import warnings
7
8
from collections import OrderedDict
8
9
from collections .abc import Mapping
12
13
from logging import getLogger
13
14
from pathlib import Path
14
15
from typing import Any , Callable , TypeVar
15
- from typing_extensions import Literal
16
16
17
17
from simple_parsing .annotation_utils .get_field_annotations import (
18
18
evaluate_string_annotation ,
44
44
# Registry mapping a type to the function used to decode a serialized value into that type.
# `str` and `bytes` are decoded by using the type itself as a constructor.
_decoding_fns: dict[type[T], Callable[[Any], T]] = {str: str, bytes: bytes}
49
49
50
50
51
- def decode_bool (v : Any ) -> bool :
52
- if isinstance (v , str ):
53
- return str2bool (v )
54
- return bool (v )
51
def register_decoding_fn(
    some_type: type[T], function: Callable[[Any], T], overwrite: bool = False
) -> None:
    """Associate `function` with `some_type` in the decoding-function registry.

    An existing registration for `some_type` is left untouched unless `overwrite` is true.
    """
    if overwrite or some_type not in _decoding_fns:
        _decoding_fns[some_type] = function
56
+
57
+
58
def _register(t: type, func: Callable, overwrite: bool = False) -> None:
    """Store `func` as the decoding function for the type `t`.

    Does nothing when `t` already has an entry and `overwrite` is false.
    """
    already_registered = t in _decoding_fns
    if already_registered and not overwrite:
        return
    _decoding_fns[t] = func
62
+
55
63
64
C = TypeVar("C", bound=Callable[[Any], Any])


def decoding_fn_for_type(some_type: type) -> Callable[[C], C]:
    """Build a decorator that registers the decorated function as the decoder for `some_type`.

    The decorated function should take a single argument (the serialized value) and return the
    decoded value. Any decoding function already registered for `some_type` is replaced, and the
    decorated function itself is returned unchanged.
    """

    def _decorator(decoder: C) -> C:
        register_decoding_fn(some_type, decoder, overwrite=True)
        return decoder

    return _decorator
78
+
79
+
80
@decoding_fn_for_type(int)
def _decode_int(v: Any) -> int:
    """Decode `v` into an `int`, warning when the conversion looks unsafe.

    An `UnsafeCastingWarning` is emitted when `v` is a `bool`, or when the converted integer
    differs from the numeric value of `v` (for example a float with a fractional part).

    Raises whatever `int(v)` raises for values that cannot be converted.
    """
    int_v = int(v)
    if isinstance(v, bool):
        # `bool` is a subclass of `int`, so the cast succeeds but is flagged as unsafe.
        unsafe = True
    elif isinstance(v, (int, str, bytes, bytearray)):
        # `int()` is exact for these inputs. Skipping the round-trip through `float` avoids
        # spurious warnings for integers above 2**53 and an OverflowError for very large ints.
        unsafe = False
    else:
        unsafe = int_v != float(v)
    if unsafe:
        warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=int_v))
    return int_v
88
+
89
+
90
@decoding_fn_for_type(float)
def _decode_float(v: Any) -> float:
    """Decode `v` into a `float`, emitting an `UnsafeCastingWarning` when `v` is a `bool`."""
    result = float(v)
    if not isinstance(v, bool):
        return result
    warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=result))
    return result
96
+
97
+
98
@decoding_fn_for_type(bool)
def _decode_bool(v: Any) -> bool:
    """Decode `v` into a `bool`.

    Strings are converted with `str2bool`; every other value goes through `bool()`. An
    `UnsafeCastingWarning` is emitted for an int or float that is not equal to 0 or 1.
    """
    if isinstance(v, str):
        # A string is never an int or float, so the numeric check below cannot apply.
        return str2bool(v)

    result = bool(v)
    is_number = isinstance(v, (int, float))
    if is_number and v not in (0, 1, 0.0, 1.0):
        warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=result))
    return result
58
107
59
108
60
109
def decode_field (
@@ -93,11 +142,36 @@ def decode_field(
93
142
94
143
decoding_function = get_decoding_fn (field_type )
95
144
96
- if is_dataclass_type (field_type ) and drop_extra_fields is not None :
97
- # Pass the drop_extra_fields argument to the decoding function.
98
- return decoding_function (raw_value , drop_extra_fields = drop_extra_fields )
145
+ _kwargs = dict (category = UnsafeCastingWarning ) if sys .version_info >= (3 , 11 ) else {}
99
146
100
- return decoding_function (raw_value )
147
+ with warnings .catch_warnings (record = True , ** _kwargs ) as warning_messages :
148
+ if is_dataclass_type (field_type ) and drop_extra_fields is not None :
149
+ # Pass the drop_extra_fields argument to the decoding function.
150
+ decoded_value = decoding_function (raw_value , drop_extra_fields = drop_extra_fields )
151
+ else :
152
+ decoded_value = decoding_function (raw_value )
153
+
154
+ for warning_message in warning_messages .copy ():
155
+ if not isinstance (warning_message .message , UnsafeCastingWarning ):
156
+ warnings .warn_explicit (
157
+ message = warning_message .message ,
158
+ category = warning_message .category ,
159
+ filename = warning_message .filename ,
160
+ lineno = warning_message .lineno ,
161
+ # module=warning_message.module,
162
+ # registry=warning_message.registry,
163
+ # module_globals=warning_message.module_globals,
164
+ )
165
+ warning_messages .remove (warning_message )
166
+
167
+ if warning_messages :
168
+ warnings .warn (
169
+ RuntimeWarning (
170
+ f"Unsafe casting occurred when deserializing field '{ name } ' of type { field_type } : "
171
+ f"raw value: { raw_value !r} , decoded value: { decoded_value !r} ."
172
+ )
173
+ )
174
+ return decoded_value
101
175
102
176
103
177
# NOTE: Disabling the caching here might help avoid some bugs, and it's unclear if this has that
@@ -224,7 +298,7 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
224
298
logger .debug (f"Decoding a typevar: { t } , bound type is { bound } ." )
225
299
if bound is not None :
226
300
return get_decoding_fn (bound )
227
-
301
+
228
302
if is_literal (t ):
229
303
logger .debug (f"Decoding a Literal field: { t } " )
230
304
possible_vals = get_type_arguments (t )
@@ -241,19 +315,6 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
241
315
return try_constructor (t )
242
316
243
317
244
- def _register (t : type , func : Callable , overwrite : bool = False ) -> None :
245
- if t not in _decoding_fns or overwrite :
246
- # logger.debug(f"Registering the type {t} with decoding function {func}")
247
- _decoding_fns [t ] = func
248
-
249
-
250
- def register_decoding_fn (
251
- some_type : type [T ], function : Callable [[Any ], T ], overwrite : bool = False
252
- ) -> None :
253
- """Register a decoding function for the type `some_type`."""
254
- _register (some_type , function , overwrite = overwrite )
255
-
256
-
257
318
def decode_optional (t : type [T ]) -> Callable [[Any | None ], T | None ]:
258
319
decode = get_decoding_fn (t )
259
320
@@ -281,15 +342,21 @@ def _try_functions(val: Any) -> T | Any:
281
342
282
343
283
344
def decode_union(*types: type[T]) -> Callable[[Any], T | Any]:
    """Create a decoding function for a Union of the given types.

    `None` members are dropped from the candidate types; when at least one was present, each
    remaining type is decoded through `decode_optional` instead of `get_decoding_fn`.
    """
    none_type = type(None)

    # Partition the Union into None and non-None types.
    non_none_types = [t for t in types if t is not none_type]
    optional = len(non_none_types) < len(types)

    decoding_fns: list[Callable[[Any], T]]
    if optional:
        decoding_fns = [decode_optional(t) for t in non_none_types]
    else:
        decoding_fns = [get_decoding_fn(t) for t in non_none_types]

    # TODO: We could be a bit smarter about the order in which we try the functions, but for now,
    # we just try the functions in the same order as the annotation, and return the result from the
    # first function that doesn't raise an exception.

    # Try using each of the non-None types, in succession. Worst case, return the value.
    return try_functions(*decoding_fns)
295
362
@@ -455,3 +522,10 @@ def constructor(val):
455
522
456
523
457
524
register_decoding_fn (Path , Path )
525
+
526
+
527
class UnsafeCastingWarning(RuntimeWarning):
    """Warning describing a decoded value obtained through an unsafe type conversion."""

    def __init__(self, raw_value: Any, decoded_value: Any) -> None:
        # Pass a message to the base class so the warning is readable when it is displayed
        # directly, instead of rendering as an empty string.
        super().__init__(
            f"Unsafe casting occurred: raw value: {raw_value!r}, "
            f"decoded value: {decoded_value!r}."
        )
        # The value as it was before decoding.
        self.raw_value = raw_value
        # The value produced by the decoding function.
        self.decoded_value = decoded_value
0 commit comments