Skip to content

Commit 65b07f3

Browse files
authored
Raise a warning on unsafe/lossy cast to int, float, and bools (#242)
* Raise a warning on lossy int casting when loading

  - Fixes #227: When the decoding function for ints would cast to an int in a way that loses precision (e.g. from a float), then we raise a warning.

  Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Broaden test, reduce # of warnings

  Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add decoding_fn_for_type decorator, improve test

  Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix type annotation, add docstring in other fn

  Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix typing issue with py38

  Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1 parent 0cb4e58 commit 65b07f3

File tree

6 files changed

+223
-47
lines changed

6 files changed

+223
-47
lines changed

simple_parsing/helpers/serialization/decoding.py

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import inspect
6+
import sys
67
import warnings
78
from collections import OrderedDict
89
from collections.abc import Mapping
@@ -12,7 +13,6 @@
1213
from logging import getLogger
1314
from pathlib import Path
1415
from typing import Any, Callable, TypeVar
15-
from typing_extensions import Literal
1616

1717
from simple_parsing.annotation_utils.get_field_annotations import (
1818
evaluate_string_annotation,
@@ -44,17 +44,66 @@
4444
_decoding_fns: dict[type[T], Callable[[Any], T]] = {
4545
# the 'primitive' types are decoded using the type fn as a constructor.
4646
t: t
47-
for t in [str, float, int, bytes]
47+
for t in [str, bytes]
4848
}
4949

5050

51-
def decode_bool(v: Any) -> bool:
52-
if isinstance(v, str):
53-
return str2bool(v)
54-
return bool(v)
51+
def register_decoding_fn(
52+
some_type: type[T], function: Callable[[Any], T], overwrite: bool = False
53+
) -> None:
54+
"""Register a decoding function for the type `some_type`."""
55+
_register(some_type, function, overwrite=overwrite)
56+
57+
58+
def _register(t: type, func: Callable, overwrite: bool = False) -> None:
59+
if t not in _decoding_fns or overwrite:
60+
# logger.debug(f"Registering the type {t} with decoding function {func}")
61+
_decoding_fns[t] = func
62+
5563

64+
C = TypeVar("C", bound=Callable[[Any], Any])
5665

57-
_decoding_fns[bool] = decode_bool
66+
67+
def decoding_fn_for_type(some_type: type) -> Callable[[C], C]:
68+
"""Registers a function to be used to convert a serialized value to the given type.
69+
70+
The function should accept one argument (the serialized value) and return the decoded value.
71+
"""
72+
73+
def _wrapper(fn: C) -> C:
74+
register_decoding_fn(some_type, fn, overwrite=True)
75+
return fn
76+
77+
return _wrapper
78+
79+
80+
@decoding_fn_for_type(int)
81+
def _decode_int(v: str) -> int:
82+
int_v = int(v)
83+
if isinstance(v, bool):
84+
warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=int_v))
85+
elif int_v != float(v):
86+
warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=int_v))
87+
return int_v
88+
89+
90+
@decoding_fn_for_type(float)
91+
def _decode_float(v: Any) -> float:
92+
float_v = float(v)
93+
if isinstance(v, bool):
94+
warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=float_v))
95+
return float_v
96+
97+
98+
@decoding_fn_for_type(bool)
99+
def _decode_bool(v: Any) -> bool:
100+
if isinstance(v, str):
101+
bool_v = str2bool(v)
102+
else:
103+
bool_v = bool(v)
104+
if isinstance(v, (int, float)) and v not in (0, 1, 0.0, 1.0):
105+
warnings.warn(UnsafeCastingWarning(raw_value=v, decoded_value=bool_v))
106+
return bool_v
58107

59108

60109
def decode_field(
@@ -93,11 +142,36 @@ def decode_field(
93142

94143
decoding_function = get_decoding_fn(field_type)
95144

96-
if is_dataclass_type(field_type) and drop_extra_fields is not None:
97-
# Pass the drop_extra_fields argument to the decoding function.
98-
return decoding_function(raw_value, drop_extra_fields=drop_extra_fields)
145+
_kwargs = dict(category=UnsafeCastingWarning) if sys.version_info >= (3, 11) else {}
99146

100-
return decoding_function(raw_value)
147+
with warnings.catch_warnings(record=True, **_kwargs) as warning_messages:
148+
if is_dataclass_type(field_type) and drop_extra_fields is not None:
149+
# Pass the drop_extra_fields argument to the decoding function.
150+
decoded_value = decoding_function(raw_value, drop_extra_fields=drop_extra_fields)
151+
else:
152+
decoded_value = decoding_function(raw_value)
153+
154+
for warning_message in warning_messages.copy():
155+
if not isinstance(warning_message.message, UnsafeCastingWarning):
156+
warnings.warn_explicit(
157+
message=warning_message.message,
158+
category=warning_message.category,
159+
filename=warning_message.filename,
160+
lineno=warning_message.lineno,
161+
# module=warning_message.module,
162+
# registry=warning_message.registry,
163+
# module_globals=warning_message.module_globals,
164+
)
165+
warning_messages.remove(warning_message)
166+
167+
if warning_messages:
168+
warnings.warn(
169+
RuntimeWarning(
170+
f"Unsafe casting occurred when deserializing field '{name}' of type {field_type}: "
171+
f"raw value: {raw_value!r}, decoded value: {decoded_value!r}."
172+
)
173+
)
174+
return decoded_value
101175

102176

103177
# NOTE: Disabling the caching here might help avoid some bugs, and it's unclear if this has that
@@ -224,7 +298,7 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
224298
logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
225299
if bound is not None:
226300
return get_decoding_fn(bound)
227-
301+
228302
if is_literal(t):
229303
logger.debug(f"Decoding a Literal field: {t}")
230304
possible_vals = get_type_arguments(t)
@@ -241,19 +315,6 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
241315
return try_constructor(t)
242316

243317

244-
def _register(t: type, func: Callable, overwrite: bool = False) -> None:
245-
if t not in _decoding_fns or overwrite:
246-
# logger.debug(f"Registering the type {t} with decoding function {func}")
247-
_decoding_fns[t] = func
248-
249-
250-
def register_decoding_fn(
251-
some_type: type[T], function: Callable[[Any], T], overwrite: bool = False
252-
) -> None:
253-
"""Register a decoding function for the type `some_type`."""
254-
_register(some_type, function, overwrite=overwrite)
255-
256-
257318
def decode_optional(t: type[T]) -> Callable[[Any | None], T | None]:
258319
decode = get_decoding_fn(t)
259320

@@ -281,15 +342,21 @@ def _try_functions(val: Any) -> T | Any:
281342

282343

283344
def decode_union(*types: type[T]) -> Callable[[Any], T | Any]:
284-
types = list(types)
285-
optional = type(None) in types
345+
types_list = list(types)
346+
optional = type(None) in types_list
347+
286348
# Partition the Union into None and non-None types.
287-
while type(None) in types:
288-
types.remove(type(None))
349+
while type(None) in types_list:
350+
types_list.remove(type(None))
289351

290352
decoding_fns: list[Callable[[Any], T]] = [
291-
decode_optional(t) if optional else get_decoding_fn(t) for t in types
353+
decode_optional(t) if optional else get_decoding_fn(t) for t in types_list
292354
]
355+
356+
# TODO: We could be a bit smarter about the order in which we try the functions, but for now,
357+
# we just try the functions in the same order as the annotation, and return the result from the
358+
# first function that doesn't raise an exception.
359+
293360
# Try using each of the non-None types, in succession. Worst case, return the value.
294361
return try_functions(*decoding_fns)
295362

@@ -455,3 +522,10 @@ def constructor(val):
455522

456523

457524
register_decoding_fn(Path, Path)
525+
526+
527+
class UnsafeCastingWarning(RuntimeWarning):
528+
def __init__(self, raw_value: Any, decoded_value: Any) -> None:
529+
super().__init__()
530+
self.raw_value = raw_value
531+
self.decoded_value = decoded_value

simple_parsing/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,13 @@ def unflatten(flattened: Mapping[tuple[K, ...], V]) -> PossiblyNestedDict[K, V]:
917917

918918

919919
def flatten_join(nested: PossiblyNestedMapping[str, V], sep: str = ".") -> dict[str, V]:
920-
"""Flatten a dictionary of dictionaries. When collisions occur, joins the keys with `sep`."""
920+
"""Flatten a dictionary of dictionaries. Joins different nesting levels with `sep` as separator.
921+
922+
>>> flatten_join({'a': {'b': 2, 'c': 3}, 'c': {'d': 3, 'e': 4}})
923+
{'a.b': 2, 'a.c': 3, 'c.d': 3, 'c.e': 4}
924+
>>> flatten_join({'a': {'b': 2, 'c': 3}, 'c': {'d': 3, 'e': 4}}, sep="/")
925+
{'a/b': 2, 'a/c': 3, 'c/d': 3, 'c/e': 4}
926+
"""
921927
return {sep.join(keys): value for keys, value in flatten(nested).items()}
922928

923929

test/conftest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import pathlib
66
import sys
7+
import warnings
78
from logging import getLogger as get_logger
89
from typing import Any, Generic, TypeVar
910

@@ -173,8 +174,15 @@ def no_stdout(capsys, caplog):
173174
pytest.fail(f"Test generated some output in stderr: '{captured.err}'")
174175

175176

177+
@pytest.fixture(autouse=False)
178+
def no_warnings_raised():
179+
with warnings.catch_warnings(record=True):
180+
warnings.simplefilter("error")
181+
yield
182+
183+
176184
@pytest.fixture
177-
def no_warnings(caplog):
185+
def no_warning_log_messages(caplog):
178186
yield
179187
for when in ("setup", "call"):
180188
messages = [x.message for x in caplog.get_records(when) if x.levelno == logging.WARNING]
@@ -183,7 +191,7 @@ def no_warnings(caplog):
183191

184192

185193
@pytest.fixture
186-
def silent(no_stdout, no_warnings):
194+
def silent(no_stdout, no_warning_log_messages):
187195
"""
188196
Test fixture that will make a test fail if it prints anything to stdout or
189197
logs warnings

test/helpers/test_enum_serialization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ class Parameters(Serializable):
3131
@pytest.mark.xfail(
3232
raises=KeyError, match="'jsonl'", strict=True, reason="Enums are saved by name, not by value."
3333
)
34-
def test_decode_enum_saved_by_value_doesnt_work():
34+
def test_decode_enum_saved_by_value_doesnt_work(tmp_path: Path):
3535
"""Test to reproduce https://github.com/lebrice/SimpleParsing/issues/219#issuecomment-1437817369"""
36-
with open("conf.yaml", "w") as f:
36+
with open(tmp_path / "conf.yaml", "w") as f:
3737
f.write(
3838
textwrap.dedent(
3939
"""\
@@ -45,7 +45,7 @@ def test_decode_enum_saved_by_value_doesnt_work():
4545
)
4646
)
4747

48-
file_config = Parameters.load_yaml("conf.yaml")
48+
file_config = Parameters.load_yaml(tmp_path / "conf.yaml")
4949
assert file_config == Parameters(hparams=Hparams(xyz=[LoggingTypes.JSONL]), p=Path("/tmp"))
5050

5151

0 commit comments

Comments
 (0)