From 1035c349777a14ea0a0141f274052b9beb4aac4d Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 25 Jul 2024 14:02:52 -0700 Subject: [PATCH] implement array operations --- reflex/experimental/vars/sequence.py | 271 +++++++++++++++++++++++++-- tests/test_var.py | 19 ++ 2 files changed, 271 insertions(+), 19 deletions(-) diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py index 179f62f41cc..7396717870e 100644 --- a/reflex/experimental/vars/sequence.py +++ b/reflex/experimental/vars/sequence.py @@ -735,6 +735,77 @@ def length(self) -> NumberVar: """ return ArrayLengthOperation(self) + @overload + @classmethod + def range(cls, stop: int | NumberVar, /) -> RangeOperation: # noqa: D418 + """Create a range of numbers. + + Args: + stop: The end of the range. + + Returns: + The range of numbers. + """ + ... + + @overload + @classmethod + def range( # noqa: D418 + cls, + start: int | NumberVar, + end: int | NumberVar, + step: int | NumberVar = 1, + /, + ) -> RangeOperation: + """Create a range of numbers. + + Args: + start: The start of the range. + end: The end of the range. + step: The step of the range. + + Returns: + The range of numbers. + """ + ... + + @classmethod + def range( + cls, + first_endpoint: int | NumberVar, + second_endpoint: int | NumberVar = None, + step: int | NumberVar | None = None, + ) -> RangeOperation: + """Create a range of numbers. + + Args: + first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range. + second_endpoint: The end of the range. + step: The step of the range. + + Returns: + The range of numbers. + """ + if second_endpoint is None: + start = 0 + end = first_endpoint + else: + start = first_endpoint + end = second_endpoint + + return RangeOperation(start, end, step or 1) + + def contains(self, other: Any) -> ArrayContainsOperation: + """Check if the array contains an element. + + Args: + other: The element to check for. + + Returns: + The array contains operation. + """ + return ArrayContainsOperation(self, other) + @dataclasses.dataclass( eq=False, @@ -1002,34 +1073,40 @@ def _cached_var_name(self) -> str: """ start, end, step = self._slice.start, self._slice.stop, self._slice.step - if step is not None and step < 0: - actual_start = end + 1 if end is not None else 0 - actual_end = start + 1 if start is not None else self.a.length() - return str( - ArraySliceOperation( - ArrayReverseOperation( - ArraySliceOperation(self.a, slice(actual_start, actual_end)) - ), - slice(None, None, -step), - ) - ) - - start = ( + normalized_start = ( LiteralVar.create(start) if start is not None else ImmutableVar.create_safe("undefined") ) - end = ( + normalized_end = ( LiteralVar.create(end) if end is not None else ImmutableVar.create_safe("undefined") ) - if step is None: - return f"{str(self.a)}.slice({str(start)}, {str(end)})" - if step == 0: - raise ValueError("slice step cannot be zero") - return f"{str(self.a)}.slice({str(start)}, {str(end)}).filter((_, i) => i % {str(step)} === 0)" + return ( + f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})" + ) + if not isinstance(step, Var): + if step < 0: + actual_start = end + 1 if end is not None else 0 + actual_end = start + 1 if start is not None else self.a.length() + return str( + ArraySliceOperation( + ArrayReverseOperation( + ArraySliceOperation(self.a, slice(actual_start, actual_end)) + ), + slice(None, None, -step), + ) + ) + if step == 0: + raise ValueError("slice step cannot be zero") + return f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" + + actual_start_reverse = end + 1 if end is not None else 0 + actual_end_reverse = start + 1 if start is not None else self.a.length() + + return f"{str(self.step)} > 0 ? {str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1231,3 +1308,159 @@ def _cached_get_all_var_data(self) -> ImmutableVarData | None: def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class RangeOperation(ArrayVar): + """Base class for immutable array vars that are the result of a range operation.""" + + start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1)) + + def __init__( + self, + start: NumberVar | int, + end: NumberVar | int, + step: NumberVar | int, + _var_data: VarData | None = None, + ): + """Initialize the range operation var. + + Args: + start: The start of the range. + end: The end of the range. + step: The step of the range. + _var_data: Additional hooks and imports associated with the Var. + """ + super(RangeOperation, self).__init__( + _var_name="", + _var_type=list, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "start", + start if isinstance(start, Var) else LiteralNumberVar(start), + ) + object.__setattr__( + self, + "end", + end if isinstance(end, Var) else LiteralNumberVar(end), + ) + object.__setattr__( + self, + "step", + step if isinstance(step, Var) else LiteralNumberVar(step), + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + start, end, step = self.start, self.end, self.step + return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(RangeOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.start._get_all_var_data(), + self.end._get_all_var_data(), + self.step._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayContainsOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of an array contains operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + + def __init__( + self, a: ArrayVar | list[Any], b: Any | Var, _var_data: VarData | None = None + ): + """Initialize the array contains operation var. + + Args: + a: The array. + b: The element to check for. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayContainsOperation, self).__init__( + _var_name="", + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b)) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.includes({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayContainsOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/tests/test_var.py b/tests/test_var.py index 7c742f2b5d8..88abaa9da82 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -20,6 +20,7 @@ NumberVar, ) from reflex.experimental.vars.sequence import ( + ArrayVar, ConcatVarOperation, LiteralArrayVar, LiteralStringVar, @@ -992,6 +993,24 @@ def test_index_operation(): assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)" +def test_array_operations(): + array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) + + assert str(array_var.length()) == "[1, 2, 3, 4, 5].length" + assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)" + assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()" + assert str(ArrayVar.range(10)) == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)" + assert str(ArrayVar.range(1, 10)) == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)" + assert ( + str(ArrayVar.range(1, 10, 2)) + == "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)" + ) + assert ( + str(ArrayVar.range(1, 10, -1)) + == "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)" + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None