Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make object var handle all mapping instead of just dict
Browse files Browse the repository at this point in the history
adhami3310 committed Jan 8, 2025
1 parent 5d877d5 commit c3b8e7a
Showing 4 changed files with 55 additions and 30 deletions.
16 changes: 16 additions & 0 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
@@ -829,6 +829,22 @@ def wrapper(*args, **kwargs):
StateIterBases = get_base_class(StateIterVar)


def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
Args:
cls: The class to check.
cls_check: The class to check against.
Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False


def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.
5 changes: 4 additions & 1 deletion reflex/vars/base.py
Original file line number Diff line number Diff line change
@@ -64,6 +64,7 @@
_isinstance,
get_origin,
has_args,
safe_issubclass,
unionize,
)

@@ -686,7 +687,9 @@ def to(

# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types:
if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output)

if fixed_output_type is None:
62 changes: 34 additions & 28 deletions reflex/vars/object.py
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
from inspect import isclass
from typing import (
Any,
Dict,
List,
Mapping,
NoReturn,
Tuple,
Type,
@@ -36,7 +36,7 @@
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar

OBJECT_TYPE = TypeVar("OBJECT_TYPE")
OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)

KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -46,7 +46,7 @@
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")


class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars."""

def _key_type(self) -> Type:
@@ -59,7 +59,7 @@ def _key_type(self) -> Type:

@overload
def _value_type(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...

@overload
@@ -87,7 +87,7 @@ def keys(self) -> ArrayVar[List[str]]:

@overload
def values(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...

@overload
@@ -103,7 +103,7 @@ def values(self) -> ArrayVar:

@overload
def entries(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...

@overload
@@ -133,49 +133,55 @@ def merge(self, other: ObjectVar):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any,
) -> Var: ...

@overload
def __getitem__(
self: (ObjectVar[Mapping[Any, bool]]),
key: Var | Any,
) -> BooleanVar: ...

@overload
def __getitem__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...

@overload
def __getitem__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
key: Var | Any,
) -> StringVar: ...

@overload
def __getitem__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...

@overload
def __getitem__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...

@overload
def __getitem__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...

@overload
def __getitem__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...

def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object.
@@ -195,49 +201,49 @@ def __getitem__(self, key: Var | Any) -> Var:
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
name: str,
) -> Var: ...

@overload
def __getattr__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
name: str,
) -> NumberVar: ...

@overload
def __getattr__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
name: str,
) -> StringVar: ...

@overload
def __getattr__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...

@overload
def __getattr__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...

@overload
def __getattr__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...

@overload
def __getattr__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...

@overload
def __getattr__(
@@ -299,7 +305,7 @@ def contains(self, key: Var | Any) -> BooleanVar:
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars."""

_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
_var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)

@@ -466,7 +472,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[
var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],
2 changes: 1 addition & 1 deletion reflex/vars/sequence.py
Original file line number Diff line number Diff line change
@@ -987,7 +987,7 @@ def __getitem__(self, i: Any) -> ArrayVar[ARRAY_VAR_TYPE] | Var:
raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i)

def length(self) -> NumberVar:
def length(self) -> NumberVar[int]:
"""Get the length of the array.
Returns:

0 comments on commit c3b8e7a

Please sign in to comment.