Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-3328] Implement __getitem__ for ArrayVar #3705

Merged
merged 26 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions reflex/experimental/vars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Experimental Immutable-Based Var System."""

from .base import ImmutableVar as ImmutableVar
from .base import LiteralObjectVar as LiteralObjectVar
from .base import LiteralVar as LiteralVar
from .base import ObjectVar as ObjectVar
from .base import var_operation as var_operation
from .function import FunctionStringVar as FunctionStringVar
from .function import FunctionVar as FunctionVar
Expand All @@ -12,6 +10,8 @@
from .number import LiteralBooleanVar as LiteralBooleanVar
from .number import LiteralNumberVar as LiteralNumberVar
from .number import NumberVar as NumberVar
from .object import LiteralObjectVar as LiteralObjectVar
from .object import ObjectVar as ObjectVar
from .sequence import ArrayJoinOperation as ArrayJoinOperation
from .sequence import ArrayVar as ArrayVar
from .sequence import ConcatVarOperation as ConcatVarOperation
Expand Down
254 changes: 155 additions & 99 deletions reflex/experimental/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@

import dataclasses
import functools
import inspect
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Type,
TypeVar,
Union,
overload,
)

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, get_origin

from reflex import constants
from reflex.base import Base
Expand All @@ -30,6 +31,17 @@
_global_vars,
)

if TYPE_CHECKING:
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation


@dataclasses.dataclass(
eq=False,
Expand All @@ -43,7 +55,7 @@ class ImmutableVar(Var):
_var_name: str = dataclasses.field()

# The type of the var.
_var_type: Type = dataclasses.field(default=Any)
_var_type: types.GenericType = dataclasses.field(default=Any)

# Extra metadata associated with the Var
_var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
Expand Down Expand Up @@ -265,9 +277,138 @@ def __format__(self, format_spec: str) -> str:
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"

@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> ToNumberVarOperation: ...

@overload
def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...

@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
) -> ToArrayOperation: ...

@overload
def to(self, output: Type[StringVar]) -> ToStringOperation: ...

@overload
def to(
self, output: Type[ObjectVar], var_type: types.GenericType = dict
) -> ToObjectOperation: ...

@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> ToFunctionOperation: ...

@overload
def to(
self, output: Type[OUTPUT], var_type: types.GenericType | None = None
) -> OUTPUT: ...

def to(
self, output: Type[OUTPUT], var_type: types.GenericType | None = None
) -> Var:
"""Convert the var to a different type.

Args:
output: The output type.
var_type: The type of the var.

Raises:
TypeError: If the var_type is not a supported type for the output.

Returns:
The converted var.
"""
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)

fixed_type = (
var_type
if var_type is None or inspect.isclass(var_type)
else get_origin(var_type)
)

if issubclass(output, NumberVar):
if fixed_type is not None and not issubclass(fixed_type, (int, float)):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
return ToNumberVarOperation(self, var_type or float)
if issubclass(output, BooleanVar):
return ToBooleanVarOperation(self)

from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation

if issubclass(output, ArrayVar):
if fixed_type is not None and not issubclass(
fixed_type, (list, tuple, set)
):
raise TypeError(
f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
)
return ToArrayOperation(self, var_type or list)
if issubclass(output, StringVar):
return ToStringOperation(self)

from .object import ObjectVar, ToObjectOperation

class ObjectVar(ImmutableVar):
"""Base class for immutable object vars."""
if issubclass(output, ObjectVar):
return ToObjectOperation(self, var_type or dict)

from .function import FunctionVar, ToFunctionOperation

if issubclass(output, FunctionVar):
if fixed_type is not None and not issubclass(fixed_type, Callable):
raise TypeError(
f"Unsupported type {var_type} for FunctionVar. Must be Callable."
)
return ToFunctionOperation(self, var_type or Callable)

return output(
_var_name=self._var_name,
_var_type=self._var_type if var_type is None else var_type,
_var_data=self._var_data,
)

def guess_type(self) -> ImmutableVar:
"""Guess the type of the var.

Returns:
The guessed type.
"""
from .number import NumberVar
from .object import ObjectVar
from .sequence import ArrayVar, StringVar

if self._var_type is Any:
return self

var_type = self._var_type

fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type)

if issubclass(fixed_type, (int, float)):
return self.to(NumberVar, var_type)
if issubclass(fixed_type, dict):
return self.to(ObjectVar, var_type)
if issubclass(fixed_type, (list, tuple, set)):
return self.to(ArrayVar, var_type)
if issubclass(fixed_type, str):
return self.to(StringVar)
return self


OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar)


class LiteralVar(ImmutableVar):
Expand Down Expand Up @@ -299,6 +440,8 @@ def create(
if value is None:
return ImmutableVar.create_safe("null", _var_data=_var_data)

from .object import LiteralObjectVar

if isinstance(value, Base):
return LiteralObjectVar(
value.dict(), _var_type=type(value), _var_data=_var_data
Expand Down Expand Up @@ -330,103 +473,16 @@ def create(
def __post_init__(self):
"""Post-initialize the var."""

def json(self) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the json() method intended to be used for?

"""Serialize the var to a JSON string.

@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralObjectVar(LiteralVar):
"""Base class for immutable literal object vars."""

_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)

def __init__(
self,
_var_value: dict[Var | Any, Var | Any],
_var_type: Type = dict,
_var_data: VarData | None = None,
):
"""Initialize the object var.

Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralObjectVar, self).__init__(
_var_name="",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
self,
"_var_value",
_var_value,
)
object.__delattr__(self, "_var_name")

def __getattr__(self, name):
"""Get an attribute of the var.

Args:
name: The name of the attribute.

Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)

@functools.cached_property
def _cached_var_name(self) -> str:
"""The name of the var.

Returns:
The name of the var.
"""
return (
"{ "
+ ", ".join(
[
f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}"
for key, value in self._var_value.items()
]
)
+ " }"
)

@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.

Returns:
The VarData of the components and all of its children.
Raises:
NotImplementedError: If the method is not implemented.
"""
return ImmutableVarData.merge(
*[
value._get_all_var_data()
for key, value in self._var_value
if isinstance(value, Var)
],
*[
key._get_all_var_data()
for key, value in self._var_value
if isinstance(key, Var)
],
self._var_data,
raise NotImplementedError(
"LiteralVar subclasses must implement the json method."
)

def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.

Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data


P = ParamSpec("P")
T = TypeVar("T", bound=ImmutableVar)
Expand Down
Loading
Loading