Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): add support for passing an optional index parameter to array map and filter #10205

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,13 +774,17 @@
array = self.f.array_reverse(self.f.array_agg(arg))
return array[self.f.safe_offset(0)]

def visit_ArrayFilter(self, op, *, arg, body, param):
def visit_ArrayFilter(self, op, *, arg, body, param, index):
return self.f.array(
sg.select(param).from_(self._unnest(arg, as_=param)).where(body)
sg.select(param)
.from_(self._unnest(arg, as_=param, offset=index))
.where(body)
)

def visit_ArrayMap(self, op, *, arg, body, param):
return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param)))
def visit_ArrayMap(self, op, *, arg, body, param, index):
return self.f.array(

Check warning on line 785 in ibis/backends/sql/compilers/bigquery/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery/__init__.py#L785

Added line #L785 was not covered by tests
sg.select(body).from_(self._unnest(arg, as_=param, offset=index))
)

def visit_ArrayZip(self, op, *, arg):
lengths = [self.f.array_length(arr) - 1 for arr in arg]
Expand Down
28 changes: 22 additions & 6 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,13 +597,29 @@ def visit_ExtractQuery(self, op, *, arg, key):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.arrayStringConcat(arg, sep)

def visit_ArrayMap(self, op, *, arg, param, body):
func = sge.Lambda(this=body, expressions=[param])
return self.f.arrayMap(func, arg)
def visit_ArrayMap(self, op, *, arg, param, body, index):
expressions = [param]
args = [arg]

if index is not None:
expressions.append(index)
args.append(self.f.range(0, self.f.length(arg)))

func = sge.Lambda(this=body, expressions=expressions)

return self.f.arrayMap(func, *args)

def visit_ArrayFilter(self, op, *, arg, param, body, index):
expressions = [param]
args = [arg]

if index is not None:
expressions.append(index)
args.append(self.f.range(0, self.f.length(arg)))

func = sge.Lambda(this=body, expressions=expressions)

def visit_ArrayFilter(self, op, *, arg, param, body):
func = sge.Lambda(this=body, expressions=[param])
return self.f.arrayFilter(func, arg)
return self.f.arrayFilter(func, *args)

def visit_ArrayRemove(self, op, *, arg, other):
x = sg.to_identifier(util.gen_name("x"))
Expand Down
24 changes: 19 additions & 5 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import lower_sample
from ibis.backends.sql.rewrites import (
lower_sample,
subtract_one_from_array_map_filter_index,
)
from ibis.util import gen_name

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,6 +45,7 @@ class DuckDBCompiler(SQLGlotCompiler):
type_mapper = DuckDBType

agg = AggGen(supports_filter=True, supports_order_by=True)
rewrites = (subtract_one_from_array_map_filter_index, *SQLGlotCompiler.rewrites)

supports_qualify = True

Expand Down Expand Up @@ -187,12 +191,22 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

return self.f.list_slice(arg, start + 1, stop)

def visit_ArrayMap(self, op, *, arg, body, param):
lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)])
def visit_ArrayMap(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

lamduh = sge.Lambda(this=body, expressions=expressions)
return self.f.list_apply(arg, lamduh)

def visit_ArrayFilter(self, op, *, arg, body, param):
lamduh = sge.Lambda(this=body, expressions=[sg.to_identifier(param)])
def visit_ArrayFilter(self, op, *, arg, body, param, index):
expressions = [sg.to_identifier(param)]

if index is not None:
expressions.append(sg.to_identifier(index))

lamduh = sge.Lambda(this=body, expressions=expressions)
return self.f.list_filter(arg, lamduh)

def visit_ArrayIntersect(self, op, *, left, right):
Expand Down
26 changes: 21 additions & 5 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import lower_sample, split_select_distinct_with_order_by
from ibis.backends.sql.rewrites import (
lower_sample,
split_select_distinct_with_order_by,
subtract_one_from_array_map_filter_index,
)
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

Expand All @@ -42,6 +46,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = Postgres
type_mapper = PostgresType
rewrites = (subtract_one_from_array_map_filter_index, *SQLGlotCompiler.rewrites)
post_rewrites = (split_select_distinct_with_order_by,)

agg = AggGen(supports_filter=True, supports_order_by=True)
Expand Down Expand Up @@ -323,16 +328,27 @@ def visit_ArrayContains(self, op, *, arg, other):
expression=self.f.array(self.cast(other, arg_dtype.value_type)),
)

def visit_ArrayFilter(self, op, *, arg, body, param):
def visit_ArrayFilter(self, op, *, arg, body, param, index):
if index is None:
alias = param
else:
alias = sge.TableAlias(this=sg.to_identifier("_"), columns=[param])

return self.f.array(
sg.select(sg.column(param, quoted=self.quoted))
.from_(sge.Unnest(expressions=[arg], alias=param))
.from_(sge.Unnest(expressions=[arg], alias=alias, offset=index))
.where(body)
)

def visit_ArrayMap(self, op, *, arg, body, param):
def visit_ArrayMap(self, op, *, arg, body, param, index):
if index is None:
alias = param
else:
alias = sge.TableAlias(this=sg.to_identifier("_"), columns=[param])
return self.f.array(
sg.select(body).from_(sge.Unnest(expressions=[arg], alias=param))
sg.select(body).from_(
sge.Unnest(expressions=[arg], alias=alias, offset=index)
)
)

def visit_ArrayPosition(self, op, *, arg, other):
Expand Down
23 changes: 16 additions & 7 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,16 +381,25 @@ def visit_MapGet(self, op, *, arg, key, default):
def visit_ArrayZip(self, op, *, arg):
return self.cast(self.f.arrays_zip(*arg), op.dtype)

def visit_ArrayMap(self, op, *, arg, body, param):
param = sge.Identifier(this=param)
func = sge.Lambda(this=body, expressions=[param])
def visit_ArrayMap(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

func = sge.Lambda(this=body, expressions=expressions)
return self.f.transform(arg, func)

def visit_ArrayFilter(self, op, *, arg, body, param):
param = sge.Identifier(this=param)
func = sge.Lambda(this=self.if_(body, param, NULL), expressions=[param])
def visit_ArrayFilter(self, op, *, arg, body, param, index):
expressions = [param]

if index is not None:
expressions.append(index)

func = sge.Lambda(this=self.if_(body, param, NULL), expressions=expressions)
transform = self.f.transform(arg, func)
func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=[param])

func = sge.Lambda(this=param.is_(sg.not_(NULL)), expressions=expressions)
return self.f.filter(transform, func)

def visit_ArrayIndex(self, op, *, arg, index):
Expand Down
79 changes: 65 additions & 14 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,31 @@
lower_log10,
lower_sample,
rewrite_empty_order_by_window,
x,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p


@replace(p.ArrayMap | p.ArrayFilter)
def multiple_args_to_zipped_struct_field_access(_, **kwargs):
# no index argument, so do nothing
if _.index is None:
return _

Check warning on line 44 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L44

Added line #L44 was not covered by tests

param = _.param.name

Check warning on line 46 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L46

Added line #L46 was not covered by tests

@replace(x @ p.Argument(name=param))

Check warning on line 48 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L48

Added line #L48 was not covered by tests
def argument_replacer(_, x, **kwargs):
return ops.StructField(x.copy(dtype=dt.Struct({"$1": _.dtype})), "$1")

Check warning on line 50 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L50

Added line #L50 was not covered by tests

@replace(x @ p.Argument(name=_.index.name))

Check warning on line 52 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L52

Added line #L52 was not covered by tests
def index_replacer(_, x, **kwargs):
return ops.StructField(

Check warning on line 54 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L54

Added line #L54 was not covered by tests
x.copy(name=param, dtype=dt.Struct({"$2": _.dtype})), "$2"
)

return _.copy(body=_.body.replace(argument_replacer | index_replacer))

Check warning on line 58 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L58

Added line #L58 was not covered by tests


class SnowflakeFuncGen(FuncGen):
Expand All @@ -54,6 +78,7 @@
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
rewrite_empty_order_by_window,
multiple_args_to_zipped_struct_field_access,
*SQLGlotCompiler.rewrites,
)

Expand Down Expand Up @@ -768,23 +793,49 @@
.subquery()
)

def visit_ArrayMap(self, op, *, arg, param, body):
def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(

Check warning on line 799 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L799

Added line #L799 was not covered by tests
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

Check warning on line 804 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L804

Added line #L804 was not covered by tests

def visit_ArrayMap(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(

Check warning on line 808 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L808

Added line #L808 was not covered by tests
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.filter(
arg,
sge.Lambda(
this=sg.and_(
body,
# necessary otherwise null values are treated as JSON nulls
# instead of SQL NULLs
self.cast(sg.to_identifier(param), op.dtype.value_type).is_(
sg.not_(NULL)
),
def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is not None:
arg = self.f.arrays_zip(

Check warning on line 815 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L815

Added line #L815 was not covered by tests
arg, self.f.array_generate_range(0, self.f.array_size(arg))
)
null_filter_arg = self.f.get(param, "$1")

Check warning on line 818 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L818

Added line #L818 was not covered by tests
# extract the field we care about
placeholder = sg.to_identifier("__ibis_snowflake_arg__")

Check warning on line 820 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L820

Added line #L820 was not covered by tests
post_process = lambda arg: self.f.transform(
arg,
sge.Lambda(
this=self.f.get(placeholder, "$1"), expressions=[placeholder]
),
expressions=[param],
),
)
else:
null_filter_arg = param

Check warning on line 828 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L828

Added line #L828 was not covered by tests
post_process = lambda arg: arg

# null_filter is necessary otherwise null values are treated as JSON
# nulls instead of SQL NULLs
null_filter = self.cast(null_filter_arg, op.dtype.value_type).is_(sg.not_(NULL))

Check warning on line 833 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L833

Added line #L833 was not covered by tests

return post_process(

Check warning on line 835 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L835

Added line #L835 was not covered by tests
self.f.filter(
arg, sge.Lambda(this=sg.and_(body, null_filter), expressions=[param])
)
)

def visit_JoinLink(self, op, *, how, table, predicates):
Expand Down
36 changes: 32 additions & 4 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,39 @@ def _neg_idx_to_pos(n, idx):

return self.f.slice(arg, start + 1, stop - start)

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
def visit_ArrayMap(self, op, *, arg, param, body, index):
if index is None:
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
else:
return self.f.zip_with(
arg,
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(this=body, expressions=[param, index]),
)

def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
def visit_ArrayFilter(self, op, *, arg, param, body, index):
if index is None:
return self.f.filter(arg, sge.Lambda(this=body, expressions=[param]))
else:
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
return self.f.filter(
self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
# semantics are: arg if predicate(arg, index) else null
this=self.if_(body, param, NULL),
expressions=[param, index],
),
),
# then, filter out elements that are null
sge.Lambda(
this=placeholder.is_(sg.not_(NULL)), expressions=[placeholder]
),
)

def visit_ArrayContains(self, op, *, arg, other):
return self.if_(
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,16 @@ def lower(_, **kwargs):
return _

return lower


@replace(p.ArrayMap | p.ArrayFilter)
def subtract_one_from_array_map_filter_index(_, **kwargs):
# no index argument, so do nothing
if _.index is None:
return _

@replace(y @ p.Argument(name=_.index.name))
def argument_replacer(_, y, **kwargs):
return ops.Subtract(y, 1)

return _.copy(body=_.body.replace(argument_replacer))
Loading