Skip to content

Commit

Permalink
implement array operations
Browse files Browse the repository at this point in the history
  • Loading branch information
adhami3310 committed Jul 25, 2024
1 parent 63b80d6 commit 1035c34
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 19 deletions.
271 changes: 252 additions & 19 deletions reflex/experimental/vars/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,77 @@ def length(self) -> NumberVar:
"""
return ArrayLengthOperation(self)

@overload
@classmethod
def range(cls, stop: int | NumberVar, /) -> RangeOperation: # noqa: D418
"""Create a range of numbers.
Args:
stop: The end of the range.
Returns:
The range of numbers.
"""
...

@overload
@classmethod
def range( # noqa: D418
cls,
start: int | NumberVar,
end: int | NumberVar,
step: int | NumberVar = 1,
/,
) -> RangeOperation:
"""Create a range of numbers.
Args:
start: The start of the range.
end: The end of the range.
step: The step of the range.
Returns:
The range of numbers.
"""
...

@classmethod
def range(
cls,
first_endpoint: int | NumberVar,
second_endpoint: int | NumberVar = None,
step: int | NumberVar | None = None,
) -> RangeOperation:
"""Create a range of numbers.
Args:
first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range.
second_endpoint: The end of the range.
step: The step of the range.
Returns:
The range of numbers.
"""
if second_endpoint is None:
start = 0
end = first_endpoint
else:
start = first_endpoint
end = second_endpoint

return RangeOperation(start, end, step or 1)

def contains(self, other: Any) -> ArrayContainsOperation:
"""Check if the array contains an element.
Args:
other: The element to check for.
Returns:
The array contains operation.
"""
return ArrayContainsOperation(self, other)


@dataclasses.dataclass(
eq=False,
Expand Down Expand Up @@ -1002,34 +1073,40 @@ def _cached_var_name(self) -> str:
"""
start, end, step = self._slice.start, self._slice.stop, self._slice.step

if step is not None and step < 0:
actual_start = end + 1 if end is not None else 0
actual_end = start + 1 if start is not None else self.a.length()
return str(
ArraySliceOperation(
ArrayReverseOperation(
ArraySliceOperation(self.a, slice(actual_start, actual_end))
),
slice(None, None, -step),
)
)

start = (
normalized_start = (
LiteralVar.create(start)
if start is not None
else ImmutableVar.create_safe("undefined")
)
end = (
normalized_end = (
LiteralVar.create(end)
if end is not None
else ImmutableVar.create_safe("undefined")
)

if step is None:
return f"{str(self.a)}.slice({str(start)}, {str(end)})"
if step == 0:
raise ValueError("slice step cannot be zero")
return f"{str(self.a)}.slice({str(start)}, {str(end)}).filter((_, i) => i % {str(step)} === 0)"
return (
f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})"
)
if not isinstance(step, Var):
if step < 0:
actual_start = end + 1 if end is not None else 0
actual_end = start + 1 if start is not None else self.a.length()
return str(
ArraySliceOperation(
ArrayReverseOperation(
ArraySliceOperation(self.a, slice(actual_start, actual_end))
),
slice(None, None, -step),
)
)
if step == 0:
raise ValueError("slice step cannot be zero")
return f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)"

actual_start_reverse = end + 1 if end is not None else 0
actual_end_reverse = start + 1 if start is not None else self.a.length()

return f"{str(self.step)} > 0 ? {str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)"

def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Expand Down Expand Up @@ -1231,3 +1308,159 @@ def _cached_get_all_var_data(self) -> ImmutableVarData | None:

def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class RangeOperation(ArrayVar):
"""Base class for immutable array vars that are the result of a range operation."""

start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))
end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))
step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1))

def __init__(
self,
start: NumberVar | int,
end: NumberVar | int,
step: NumberVar | int,
_var_data: VarData | None = None,
):
"""Initialize the range operation var.
Args:
start: The start of the range.
end: The end of the range.
step: The step of the range.
_var_data: Additional hooks and imports associated with the Var.
"""
super(RangeOperation, self).__init__(
_var_name="",
_var_type=list,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
self,
"start",
start if isinstance(start, Var) else LiteralNumberVar(start),
)
object.__setattr__(
self,
"end",
end if isinstance(end, Var) else LiteralNumberVar(end),
)
object.__setattr__(
self,
"step",
step if isinstance(step, Var) else LiteralNumberVar(step),
)
object.__delattr__(self, "_var_name")

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
start, end, step = self.start, self.end, self.step
return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})"

def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(RangeOperation, self), name)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self.start._get_all_var_data(),
self.end._get_all_var_data(),
self.step._get_all_var_data(),
self._var_data,
)

def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data


@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayContainsOperation(BooleanVar):
"""Base class for immutable boolean vars that are the result of an array contains operation."""

a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

def __init__(
self, a: ArrayVar | list[Any], b: Any | Var, _var_data: VarData | None = None
):
"""Initialize the array contains operation var.
Args:
a: The array.
b: The element to check for.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ArrayContainsOperation, self).__init__(
_var_name="",
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b))
object.__delattr__(self, "_var_name")

@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"{str(self.a)}.includes({str(self.b)})"

def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ArrayContainsOperation, self), name)

@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
)

def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
19 changes: 19 additions & 0 deletions tests/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NumberVar,
)
from reflex.experimental.vars.sequence import (
ArrayVar,
ConcatVarOperation,
LiteralArrayVar,
LiteralStringVar,
Expand Down Expand Up @@ -992,6 +993,24 @@ def test_index_operation():
assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"


def test_array_operations():
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])

assert str(array_var.length()) == "[1, 2, 3, 4, 5].length"
assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)"
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()"
assert str(ArrayVar.range(10)) == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)"
assert str(ArrayVar.range(1, 10)) == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)"
assert (
str(ArrayVar.range(1, 10, 2))
== "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)"
)
assert (
str(ArrayVar.range(1, 10, -1))
== "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)"
)


def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None
Expand Down

0 comments on commit 1035c34

Please sign in to comment.