Skip to content

Commit

Permalink
Fix Literal type checking in newer versions of typing_extensions/Py…
Browse files Browse the repository at this point in the history
…thon, and make 1 a valid value for `float` type checks.
  • Loading branch information
matthewwardrop committed Jul 26, 2023
1 parent a7a11fe commit 65e9854
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion spec_classes/spec_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def bootstrap(self, spec_cls: type):
# Update __annotations__ attribute to be consistent with spec_class
# typings (unless already defined on the class contrarily)
if not hasattr(spec_cls, "__annotations__"):
spec_cls.__annotations__ = {}
spec_cls.__annotations__ = {} # pragma: no cover
for attr, attr_spec in metadata.attrs.items():
if attr_spec.owner is spec_cls and attr not in spec_cls.__annotations__:
spec_cls.__annotations__[attr] = attr_spec.type
Expand Down
13 changes: 11 additions & 2 deletions spec_classes/utils/type_checking.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import numbers
from collections.abc import Sequence as SequenceMutator
from collections.abc import Set as SetMutator
from typing import (
Expand All @@ -11,7 +12,12 @@
Union,
_GenericAlias,
) # pylint: disable=protected-access
from typing_extensions import Literal
from typing_extensions import Literal as LiteralExtension

try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal # pylint: disable=reimported


def type_match(type_input: Type, type_reference: type) -> bool:
Expand All @@ -32,11 +38,14 @@ def check_type(value: Any, attr_type: Type) -> bool:
if attr_type is Any:
return True

if attr_type is float:
attr_type = numbers.Real

if hasattr(attr_type, "__origin__"): # we are dealing with a `typing` object.
if attr_type.__origin__ is Union:
return any(check_type(value, type_) for type_ in attr_type.__args__)

if attr_type.__origin__ is Literal:
if attr_type.__origin__ in (Literal, LiteralExtension):
return value in attr_type.__args__

if isinstance(attr_type, _GenericAlias):
Expand Down
2 changes: 1 addition & 1 deletion tests/types/test_validated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_bounded(self):
assert not isinstance(1.0, b2)
assert isinstance(2.0, b2)
assert isinstance(1.5, b2)
assert not isinstance(2, b2)
assert isinstance(2, b2)
assert not isinstance(0.0, b2)
assert not isinstance(5, b2)
assert b2.__name__ == "float∊(1,2]"
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def test_type_checking(self):
assert check_type("hi", Literal["hi"])
assert not check_type(1, Literal["hi"])

assert check_type(1, float)
assert not check_type(1.0, int)

def test_get_collection_item_type(self):
assert get_collection_item_type(list) is Any
assert get_collection_item_type(List) is Any
Expand Down

0 comments on commit 65e9854

Please sign in to comment.