diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d69a..ac30342e214 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -829,6 +829,22 @@ def wrapper(*args, **kwargs): StateIterBases = get_base_class(StateIterVar) +def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]): + """Check if a class is a subclass of another class. Returns False if internal error occurs. + + Args: + cls: The class to check. + cls_check: The class to check against. + + Returns: + Whether the class is a subclass of the other class. + """ + try: + return issubclass(cls, cls_check) + except TypeError: + return False + + def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: """Check if a type hint is a subclass of another type hint. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index a4496d77d5c..fed4eb86302 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -64,6 +64,7 @@ _isinstance, get_origin, has_args, + safe_issubclass, unionize, ) @@ -686,7 +687,9 @@ def to( # If the first argument is a python type, we map it to the corresponding Var type. for var_subclass in _var_subclasses[::-1]: - if fixed_output_type in var_subclass.python_types: + if fixed_output_type in var_subclass.python_types or safe_issubclass( + fixed_output_type, var_subclass.python_types + ): return self.to(var_subclass.var_subclass, output) if fixed_output_type is None: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 5de431f5ab1..30c012b1633 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -8,8 +8,8 @@ from inspect import isclass from typing import ( Any, - Dict, List, + Mapping, NoReturn, Tuple, Type, @@ -36,7 +36,7 @@ from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -46,7 +46,7 @@ OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") -class ObjectVar(Var[OBJECT_TYPE], python_types=dict): +class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): """Base class for immutable object vars.""" def _key_type(self) -> Type: @@ -59,7 +59,7 @@ def _key_type(self) -> Type: @overload def _value_type( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -87,7 +87,7 @@ def keys(self) -> ArrayVar[List[str]]: @overload def values( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +103,7 @@ def values(self) -> ArrayVar: @overload def entries( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,49 +133,55 @@ def merge(self, other: ObjectVar): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], key: Var | Any, ) -> Var: ... + @overload + def __getitem__( + self: (ObjectVar[Mapping[Any, bool]]), + key: Var | Any, + ) -> BooleanVar: ... + @overload def __getitem__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... def __getitem__(self, key: Var | Any) -> Var: """Get an item from the object. @@ -195,49 +201,49 @@ def __getitem__(self, key: Var | Any) -> Var: # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @overload def __getattr__( @@ -299,7 +305,7 @@ def contains(self, key: Var | Any) -> BooleanVar: class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field( default_factory=dict ) @@ -466,7 +472,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Dict[ + var_type=Mapping[ Union[lhs._key_type(), rhs._key_type()], Union[lhs._value_type(), rhs._value_type()], ], diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 5864e70b998..1b11f50e6e1 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -987,7 +987,7 @@ def __getitem__(self, i: Any) -> ArrayVar[ARRAY_VAR_TYPE] | Var: raise_unsupported_operand_types("[]", (type(self), type(i))) return array_item_operation(self, i) - def length(self) -> NumberVar: + def length(self) -> NumberVar[int]: """Get the length of the array. Returns: