Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-3328] Implement __getitem__ for ArrayVar #3705

Merged
merged 26 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions reflex/experimental/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import functools
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Type,
TypeVar,
overload,
)

from typing_extensions import ParamSpec
Expand All @@ -28,6 +30,17 @@
_global_vars,
)

if TYPE_CHECKING:
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation


@dataclasses.dataclass(
eq=False,
Expand Down Expand Up @@ -263,18 +276,94 @@ def __format__(self, format_spec: str) -> str:
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"

def to(self, output: Type[OUTPUT]) -> OUTPUT:
@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> ToNumberVarOperation: ...

@overload
def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...

@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
) -> ToArrayOperation: ...

@overload
def to(self, output: Type[StringVar]) -> ToStringOperation: ...

@overload
def to(
self, output: Type[ObjectVar], var_type: Type = dict
) -> ToObjectOperation: ...

@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> ToFunctionOperation: ...

@overload
def to(self, output: Type[OUTPUT], var_type: Type | None = None) -> OUTPUT: ...

def to(self, output: Type[OUTPUT], var_type: Type | None = None) -> Var:
"""Convert the var to a different type.

Args:
output: The output type.
var_type: The type of the var.

Raises:
TypeError: If the var_type is not a supported type for the output.

Returns:
The converted var.
"""
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)

if issubclass(output, NumberVar):
if var_type is not None and not issubclass(var_type, (int, float)):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
return ToNumberVarOperation(self, var_type or float)
if issubclass(output, BooleanVar):
return ToBooleanVarOperation(self)

from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation

if issubclass(output, ArrayVar):
if var_type is not None and not issubclass(var_type, (list, tuple, set)):
raise TypeError(
f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
)
return ToArrayOperation(self, var_type or list)
if issubclass(output, StringVar):
return ToStringOperation(self)

from .object import ObjectVar, ToObjectOperation

if issubclass(output, ObjectVar):
return ToObjectOperation(self, var_type or dict)

from .function import FunctionVar, ToFunctionOperation

if issubclass(output, FunctionVar):
if var_type is not None and not issubclass(var_type, Callable):
raise TypeError(
f"Unsupported type {var_type} for FunctionVar. Must be Callable."
)
return ToFunctionOperation(self, var_type or Callable)

return output(
_var_name=self._var_name,
_var_type=self._var_type,
_var_type=self._var_type if var_type is None else var_type,
_var_data=self._var_data,
)

Expand Down
78 changes: 77 additions & 1 deletion reflex/experimental/vars/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dataclasses
import sys
from functools import cached_property
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Type, Union

from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.vars import ImmutableVarData, Var, VarData
Expand Down Expand Up @@ -212,3 +212,79 @@ def _get_all_var_data(self) -> ImmutableVarData | None:

def __post_init__(self):
"""Post-initialize the var."""


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToFunctionOperation(FunctionVar):
"""Base class of converting a var to a function."""

_original_var: Var = dataclasses.field(
default_factory=lambda: LiteralVar.create(None)
)

def __init__(
self,
original_var: Var,
_var_type: Type[Callable] = Callable,
_var_data: VarData | None = None,
) -> None:
"""Initialize the function with arguments var.

Args:
original_var: The original var to convert to a function.
_var_type: The type of the function.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ToFunctionOperation, self).__init__(
_var_name=f"",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_original_var", original_var)
object.__delattr__(self, "_var_name")

def __getattr__(self, name):
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return str(self._original_var)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_var._get_all_var_data(),
self._var_data,
)

def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.

Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
138 changes: 138 additions & 0 deletions reflex/experimental/vars/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,3 +1293,141 @@ def __hash__(self) -> int:

number_types = Union[NumberVar, LiteralNumberVar, int, float]
boolean_types = Union[BooleanVar, LiteralBooleanVar, bool]


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToNumberVarOperation(NumberVar):
"""Base class for immutable number vars that are the result of a number operation."""

_original_value: Var = dataclasses.field(
default_factory=lambda: LiteralNumberVar(0)
)

def __init__(
self,
_original_value: Var,
_var_type: type[int] | type[float] = float,
_var_data: VarData | None = None,
):
"""Initialize the number var.

Args:
_original_value: The original value.
_var_type: The type of the Var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ToNumberVarOperation, self).__init__(
_var_name="",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_original_value", _original_value)
object.__delattr__(self, "_var_name")

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return str(self._original_value)

def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ToNumberVarOperation, self), name)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_value._get_all_var_data(), self._var_data
)

def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToBooleanVarOperation(BooleanVar):
"""Base class for immutable boolean vars that are the result of a boolean operation."""

_original_value: Var = dataclasses.field(
default_factory=lambda: LiteralBooleanVar(False)
)

def __init__(
self,
_original_value: Var,
_var_data: VarData | None = None,
):
"""Initialize the boolean var.

Args:
_original_value: The original value.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ToBooleanVarOperation, self).__init__(
_var_name="",
_var_type=bool,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_original_value", _original_value)
object.__delattr__(self, "_var_name")

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return str(self._original_value)

def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ToBooleanVarOperation, self), name)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_value._get_all_var_data(), self._var_data
)

def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
Loading
Loading