Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-3328] Implement __getitem__ for ArrayVar #3705

Merged
merged 26 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions reflex/experimental/vars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Experimental Immutable-Based Var System."""

from .base import ImmutableVar as ImmutableVar
from .base import LiteralObjectVar as LiteralObjectVar
from .base import LiteralVar as LiteralVar
from .base import ObjectVar as ObjectVar
from .base import var_operation as var_operation
from .function import FunctionStringVar as FunctionStringVar
from .function import FunctionVar as FunctionVar
Expand All @@ -12,6 +10,8 @@
from .number import LiteralBooleanVar as LiteralBooleanVar
from .number import LiteralNumberVar as LiteralNumberVar
from .number import NumberVar as NumberVar
from .object import LiteralObjectVar as LiteralObjectVar
from .object import ObjectVar as ObjectVar
from .sequence import ArrayJoinOperation as ArrayJoinOperation
from .sequence import ArrayVar as ArrayVar
from .sequence import ConcatVarOperation as ConcatVarOperation
Expand Down
208 changes: 107 additions & 101 deletions reflex/experimental/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import functools
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Type,
TypeVar,
Union,
overload,
)

from typing_extensions import ParamSpec
Expand All @@ -30,6 +30,17 @@
_global_vars,
)

if TYPE_CHECKING:
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation


@dataclasses.dataclass(
eq=False,
Expand Down Expand Up @@ -265,9 +276,99 @@ def __format__(self, format_spec: str) -> str:
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"

@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> ToNumberVarOperation: ...

@overload
def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...

@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
) -> ToArrayOperation: ...

@overload
def to(self, output: Type[StringVar]) -> ToStringOperation: ...

@overload
def to(
self, output: Type[ObjectVar], var_type: Type = dict
) -> ToObjectOperation: ...

@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> ToFunctionOperation: ...

@overload
def to(self, output: Type[OUTPUT], var_type: Type | None = None) -> OUTPUT: ...

def to(self, output: Type[OUTPUT], var_type: Type | None = None) -> Var:
"""Convert the var to a different type.

Args:
output: The output type.
var_type: The type of the var.

Raises:
TypeError: If the var_type is not a supported type for the output.

Returns:
The converted var.
"""
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)

if issubclass(output, NumberVar):
if var_type is not None and not issubclass(var_type, (int, float)):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
return ToNumberVarOperation(self, var_type or float)
if issubclass(output, BooleanVar):
return ToBooleanVarOperation(self)

from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation

if issubclass(output, ArrayVar):
if var_type is not None and not issubclass(var_type, (list, tuple, set)):
raise TypeError(
f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
)
return ToArrayOperation(self, var_type or list)
if issubclass(output, StringVar):
return ToStringOperation(self)

from .object import ObjectVar, ToObjectOperation

class ObjectVar(ImmutableVar):
"""Base class for immutable object vars."""
if issubclass(output, ObjectVar):
return ToObjectOperation(self, var_type or dict)

from .function import FunctionVar, ToFunctionOperation

if issubclass(output, FunctionVar):
if var_type is not None and not issubclass(var_type, Callable):
raise TypeError(
f"Unsupported type {var_type} for FunctionVar. Must be Callable."
)
return ToFunctionOperation(self, var_type or Callable)

return output(
_var_name=self._var_name,
_var_type=self._var_type if var_type is None else var_type,
_var_data=self._var_data,
)


OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar)


class LiteralVar(ImmutableVar):
Expand Down Expand Up @@ -299,6 +400,8 @@ def create(
if value is None:
return ImmutableVar.create_safe("null", _var_data=_var_data)

from .object import LiteralObjectVar

if isinstance(value, Base):
return LiteralObjectVar(
value.dict(), _var_type=type(value), _var_data=_var_data
Expand Down Expand Up @@ -331,103 +434,6 @@ def __post_init__(self):
"""Post-initialize the var."""


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralObjectVar(LiteralVar):
"""Base class for immutable literal object vars."""

_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)

def __init__(
self,
_var_value: dict[Var | Any, Var | Any],
_var_type: Type = dict,
_var_data: VarData | None = None,
):
"""Initialize the object var.

Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralObjectVar, self).__init__(
_var_name="",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
self,
"_var_value",
_var_value,
)
object.__delattr__(self, "_var_name")

def __getattr__(self, name):
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)

@functools.cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return (
"{ "
+ ", ".join(
[
f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}"
for key, value in self._var_value.items()
]
)
+ " }"
)

@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
value._get_all_var_data()
for key, value in self._var_value
if isinstance(value, Var)
],
*[
key._get_all_var_data()
for key, value in self._var_value
if isinstance(key, Var)
],
self._var_data,
)

def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.

Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data


P = ParamSpec("P")
T = TypeVar("T", bound=ImmutableVar)

Expand Down
78 changes: 77 additions & 1 deletion reflex/experimental/vars/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dataclasses
import sys
from functools import cached_property
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Type, Union

from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.vars import ImmutableVarData, Var, VarData
Expand Down Expand Up @@ -212,3 +212,79 @@ def _get_all_var_data(self) -> ImmutableVarData | None:

def __post_init__(self):
"""Post-initialize the var."""


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToFunctionOperation(FunctionVar):
"""Base class of converting a var to a function."""

_original_var: Var = dataclasses.field(
default_factory=lambda: LiteralVar.create(None)
)

def __init__(
self,
original_var: Var,
_var_type: Type[Callable] = Callable,
_var_data: VarData | None = None,
) -> None:
"""Initialize the function with arguments var.

Args:
original_var: The original var to convert to a function.
_var_type: The type of the function.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ToFunctionOperation, self).__init__(
_var_name=f"",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_original_var", original_var)
object.__delattr__(self, "_var_name")

def __getattr__(self, name):
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return str(self._original_var)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_var._get_all_var_data(),
self._var_data,
)

def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.

Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
Loading
Loading