Skip to content

Commit

Permalink
enhance types
Browse files Browse the repository at this point in the history
  • Loading branch information
adhami3310 committed Jul 29, 2024
1 parent 0036e92 commit 622d836
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 27 deletions.
92 changes: 87 additions & 5 deletions reflex/experimental/vars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dataclasses
import sys
from functools import cached_property
from typing import Any, Dict, Type, Union
from typing import Any, Dict, Tuple, Type, Union

from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.experimental.vars.sequence import ArrayVar
Expand All @@ -15,6 +15,22 @@
class ObjectVar(ImmutableVar):
"""Base class for immutable object vars."""

def _key_type(self) -> Type:
"""Get the type of the keys of the object.
Returns:
The type of the keys of the object.
"""
return ImmutableVar

def _value_type(self) -> Type:
"""Get the type of the values of the object.
Returns:
The type of the values of the object.
"""
return ImmutableVar

def keys(self) -> ObjectKeysOperation:
"""Get the keys of the object.
Expand Down Expand Up @@ -88,18 +104,19 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
def __init__(
self,
_var_value: dict[Var | Any, Var | Any],
_var_type: Type = dict,
_var_type: Type | None = None,
_var_data: VarData | None = None,
):
"""Initialize the object var.
Args:
_var_value: The value of the var.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralObjectVar, self).__init__(
_var_name="",
_var_type=_var_type,
_var_type=type(_var_value) if _var_type is None else _var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
Expand All @@ -109,6 +126,27 @@ def __init__(
)
object.__delattr__(self, "_var_name")

def _key_type(self) -> Type:
"""Get the type of the keys of the object.
Returns:
The type of the keys of the object.
"""
print(self._var_type)
return (
self._var_type.__args__[0] if hasattr(self._var_type, "__args__") else Any
)

def _value_type(self) -> Type:
"""Get the type of the values of the object.
Returns:
The type of the values of the object.
"""
return (
self._var_type.__args__[1] if hasattr(self._var_type, "__args__") else Any
)

def __getattr__(self, name):
"""Get an attribute of the var.
Expand Down Expand Up @@ -183,6 +221,7 @@ class ObjectToArrayOperation(ArrayVar):
def __init__(
self,
_var_value: ObjectVar,
_var_type: Type = list,
_var_data: VarData | None = None,
):
"""Initialize the object to array operation.
Expand All @@ -193,7 +232,7 @@ def __init__(
"""
super(ObjectToArrayOperation, self).__init__(
_var_name="",
_var_type=list,
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "value", _var_value)
Expand Down Expand Up @@ -247,6 +286,19 @@ def _get_all_var_data(self) -> ImmutableVarData | None:
class ObjectKeysOperation(ObjectToArrayOperation):
"""Operation to get the keys of an object."""

def __init__(
self,
value: ObjectVar,
_var_data: VarData | None = None,
):
"""Initialize the object keys operation.
Args:
value: The value of the operation.
_var_data: Additional hooks and imports associated with the operation.
"""
super(ObjectKeysOperation, self).__init__(value, value._key_type(), _var_data)

@cached_property
def _cached_var_name(self) -> str:
"""The name of the operation.
Expand All @@ -260,6 +312,21 @@ def _cached_var_name(self) -> str:
class ObjectValuesOperation(ObjectToArrayOperation):
"""Operation to get the values of an object."""

def __init__(
self,
value: ObjectVar,
_var_data: VarData | None = None,
):
"""Initialize the object values operation.
Args:
value: The value of the operation.
_var_data: Additional hooks and imports associated with the operation.
"""
super(ObjectValuesOperation, self).__init__(
value, value._value_type(), _var_data
)

@cached_property
def _cached_var_name(self) -> str:
"""The name of the operation.
Expand All @@ -273,6 +340,21 @@ def _cached_var_name(self) -> str:
class ObjectEntriesOperation(ObjectToArrayOperation):
"""Operation to get the entries of an object."""

def __init__(
self,
value: ObjectVar,
_var_data: VarData | None = None,
):
"""Initialize the object entries operation.
Args:
value: The value of the operation.
_var_data: Additional hooks and imports associated with the operation.
"""
super(ObjectEntriesOperation, self).__init__(
value, Tuple[value._key_type(), value._value_type()], _var_data
)

@cached_property
def _cached_var_name(self) -> str:
"""The name of the operation.
Expand Down Expand Up @@ -386,7 +468,7 @@ def __init__(
"""
super(ObjectItemOperation, self).__init__(
_var_name="",
_var_type=value._var_type,
_var_type=value._value_type(),
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "value", value)
Expand Down
43 changes: 21 additions & 22 deletions reflex/experimental/vars/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class ArrayJoinOperation(StringVar):
)

def __init__(
self, a: ArrayVar | list, b: StringVar | str, _var_data: VarData | None = None
self, a: ArrayVar, b: StringVar | str, _var_data: VarData | None = None
):
"""Initialize the array join operation var.
Expand All @@ -441,9 +441,7 @@ def __init__(
_var_type=str,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
self, "a", a if isinstance(a, Var) else LiteralArrayVar.create(a)
)
object.__setattr__(self, "a", a)
object.__setattr__(
self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b)
)
Expand Down Expand Up @@ -801,19 +799,21 @@ class LiteralArrayVar(LiteralVar, ArrayVar):

def __init__(
self,
_var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any],
_var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any],
_var_type: type[list] | type[tuple] | type[set] | None = None,
_var_data: VarData | None = None,
):
"""Initialize the array var.
Args:
_var_value: The value of the var.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralArrayVar, self).__init__(
_var_name="",
_var_data=ImmutableVarData.merge(_var_data),
_var_type=list,
_var_type=type(_var_value) if _var_type is None else _var_type,
)
object.__setattr__(self, "_var_value", _var_value)
object.__delattr__(self, "_var_name")
Expand Down Expand Up @@ -898,7 +898,7 @@ def __init__(
"""
super(StringSplitOperation, self).__init__(
_var_name="",
_var_type=list,
_var_type=list[str],
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
Expand Down Expand Up @@ -956,7 +956,7 @@ class ArrayToArrayOperation(ArrayVar):

a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))

def __init__(self, a: ArrayVar | list[Any], _var_data: VarData | None = None):
def __init__(self, a: ArrayVar, _var_data: VarData | None = None):
"""Initialize the array to array operation var.
Args:
Expand All @@ -965,10 +965,10 @@ def __init__(self, a: ArrayVar | list[Any], _var_data: VarData | None = None):
"""
super(ArrayToArrayOperation, self).__init__(
_var_name="",
_var_type=list,
_var_type=a._var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
object.__setattr__(self, "a", a)
object.__delattr__(self, "_var_name")

@cached_property
Expand Down Expand Up @@ -1022,9 +1022,7 @@ class ArraySliceOperation(ArrayVar):
a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
_slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None))

def __init__(
self, a: ArrayVar | list[Any], _slice: slice, _var_data: VarData | None = None
):
def __init__(self, a: ArrayVar, _slice: slice, _var_data: VarData | None = None):
"""Initialize the string slice operation var.
Args:
Expand All @@ -1034,10 +1032,10 @@ def __init__(
"""
super(ArraySliceOperation, self).__init__(
_var_name="",
_var_type=str,
_var_type=a._var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
object.__setattr__(self, "a", a)
object.__setattr__(self, "_slice", _slice)
object.__delattr__(self, "_var_name")

Expand Down Expand Up @@ -1151,7 +1149,7 @@ class ArrayToNumberOperation(NumberVar):
default_factory=lambda: LiteralArrayVar([]),
)

def __init__(self, a: ArrayVar | list[Any], _var_data: VarData | None = None):
def __init__(self, a: ArrayVar, _var_data: VarData | None = None):
"""Initialize the string to number operation var.
Args:
Expand Down Expand Up @@ -1229,7 +1227,7 @@ class ArrayItemOperation(ImmutableVar):

def __init__(
self,
a: ArrayVar | list[Any],
a: ArrayVar,
i: NumberVar | int,
_var_data: VarData | None = None,
):
Expand All @@ -1242,7 +1240,9 @@ def __init__(
"""
super(ArrayItemOperation, self).__init__(
_var_name="",
_var_type=Any,
_var_type=(
a._var_type.__args__[0] if hasattr(a._var_type, "__args__") else Any
),
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
Expand Down Expand Up @@ -1391,9 +1391,7 @@ class ArrayContainsOperation(BooleanVar):
a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

def __init__(
self, a: ArrayVar | list[Any], b: Any | Var, _var_data: VarData | None = None
):
def __init__(self, a: ArrayVar, b: Any | Var, _var_data: VarData | None = None):
"""Initialize the array contains operation var.
Args:
Expand All @@ -1403,9 +1401,10 @@ def __init__(
"""
super(ArrayContainsOperation, self).__init__(
_var_name="",
_var_type=bool,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
object.__setattr__(self, "a", a)
object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b))
object.__delattr__(self, "_var_name")

Expand Down

0 comments on commit 622d836

Please sign in to comment.