Skip to content

Commit 850fc36

Browse files
committed
feat(api): add approx_quantiles for computing approximate quantiles
1 parent 006486f commit 850fc36

File tree

23 files changed

+307
-24
lines changed

23 files changed

+307
-24
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
[
3+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
4+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
5+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
6+
] AS `qs`
7+
FROM `functional_alltypes` AS `t0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS) AS `qs`
3+
FROM `functional_alltypes` AS `t0`
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
approx_quantiles(`t0`.`double_col`, 2 IGNORE NULLS)[1] AS `qs`
3+
FROM `functional_alltypes` AS `t0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
[
3+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
4+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
5+
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
6+
] AS `qs`
7+
FROM `functional_alltypes` AS `t0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
approx_quantiles(`t0`.`double_col`, 100000 IGNORE NULLS)[33333] AS `qs`
3+
FROM `functional_alltypes` AS `t0`

ibis/backends/bigquery/tests/unit/test_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,3 +677,19 @@ def test_time_from_hms_with_micros(snapshot):
677677
literal = ibis.literal(datetime.time(12, 34, 56))
678678
result = ibis.to_sql(literal, dialect="bigquery")
679679
snapshot.assert_match(result, "no_micros.sql")
680+
681+
682+
@pytest.mark.parametrize(
    "quantiles",
    [
        param(0.5, id="scalar"),
        param(1 / 3, id="tricky-scalar"),
        param([0.25, 0.5, 0.75], id="array"),
        param([0.5, 0.25, 0.75], id="shuffled-array"),
        param([0, 0.25, 0.5, 0.75, 1], id="complete-array"),
    ],
)
def test_approx_quantiles(alltypes, quantiles, snapshot):
    """Check the BigQuery SQL generated for ``approx_quantile``.

    Each parametrized case (a scalar, a non-terminating fraction, and
    ordered / shuffled / complete lists of quantiles) is compiled to
    BigQuery SQL and compared against the stored ``out.sql`` snapshot.
    """
    expr = alltypes.double_col.approx_quantile(quantiles).name("qs")
    snapshot.assert_match(ibis.to_sql(expr, dialect="bigquery"), "out.sql")

ibis/backends/polars/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ def execute_mode(op, **kw):
785785

786786

787787
@translate.register(ops.Quantile)
788+
@translate.register(ops.ApproxQuantile)
788789
def execute_quantile(op, **kw):
789790
arg = translate(op.arg, **kw)
790791
quantile = translate(op.quantile, **kw)

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import decimal
6+
import math
57
import re
68
from typing import TYPE_CHECKING, Any
79

@@ -392,6 +394,41 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
392394

393395
return sge.GroupConcat(this=arg, separator=sep)
394396

397+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate (multi-)quantile to ``approx_quantiles``.

    The requested quantile(s) must be a literal: a single value for
    ``ops.ApproxQuantile``, or a list for the multi-quantile op that
    shares this implementation.

    Raises
    ------
    com.UnsupportedOperationError
        If ``op.quantile`` is not an ``ops.Literal``.
    """
    literal = op.quantile
    if not isinstance(literal, ops.Literal):
        raise com.UnsupportedOperationError(
            "quantile must be a literal in BigQuery"
        )

    # BigQuery's `APPROX_QUANTILES(col, resolution)` returns an array of
    # `resolution + 1` quantiles, so we pick a resolution at which every
    # requested quantile lands on an array index and then select those
    # indices. The resolution is arbitrarily capped at 100,000 to keep it
    # from growing excessively; the results are approximate anyway.
    requested = util.promote_list(literal.value)
    # Going through `str` yields the exact decimal fraction that was
    # written (e.g. 0.25 -> 1/4) rather than the binary float expansion.
    ratios = [decimal.Decimal(str(q)).as_integer_ratio() for q in requested]
    denominators = [den for _, den in ratios]
    resolution = min(math.lcm(*denominators), 100_000)
    indices = [(num * resolution) // den for num, den in ratios]

    if where is not None:
        arg = self.if_(where, arg, NULL)

    if not op.arg.dtype.is_floating():
        arg = self.cast(arg, dt.float64)

    buckets = self.f.approx_quantiles(
        arg, sge.IgnoreNulls(this=sge.convert(resolution))
    )

    # Scalar op: a single element of the quantiles array.
    if isinstance(op, ops.ApproxQuantile):
        return buckets[indices[0]]

    # Multi op: rebuild the array from the selected indices, unless the
    # request is exactly the full array BigQuery already returns.
    if indices != list(range(resolution + 1)):
        return sge.Array(expressions=[buckets[i] for i in indices])
    return buckets


visit_ApproxMultiQuantile = visit_ApproxQuantile
431+
395432
def visit_FloorDivide(self, op, *, left, right):
396433
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)
397434

ibis/backends/sql/compilers/clickhouse.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,24 @@ def visit_CountStar(self, op, *, where, arg):
188188
return self.f.countIf(where)
189189
return sge.Count(this=STAR)
190190

191-
def visit_Quantile(self, op, *, arg, quantile, where):
192-
if where is None:
193-
return self.agg.quantile(arg, quantile, where=where)
194-
195-
func = "quantile" + "s" * isinstance(op, ops.MultiQuantile)
191+
def _visit_quantile(self, func, arg, quantile, where):
196192
return sge.ParameterizedAgg(
197-
this=f"{func}If",
193+
this=f"{func}If" if where is not None else func,
198194
expressions=util.promote_list(quantile),
199-
params=[arg, where],
195+
params=[arg, where] if where is not None else [arg],
200196
)
201197

202-
visit_MultiQuantile = visit_Quantile
198+
def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a single quantile via the ``quantile`` aggregate."""
    return self._visit_quantile(
        func="quantile", arg=arg, quantile=quantile, where=where
    )
200+
201+
def visit_MultiQuantile(self, op, *, arg, quantile, where):
    """Compile several quantiles at once via the ``quantiles`` aggregate."""
    return self._visit_quantile(
        func="quantiles", arg=arg, quantile=quantile, where=where
    )
203+
204+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile via the ``quantileTDigest`` aggregate."""
    return self._visit_quantile(
        func="quantileTDigest", arg=arg, quantile=quantile, where=where
    )
206+
207+
def visit_ApproxMultiQuantile(self, op, *, arg, quantile, where):
    """Compile approximate quantiles via the ``quantilesTDigest`` aggregate."""
    return self._visit_quantile(
        func="quantilesTDigest", arg=arg, quantile=quantile, where=where
    )
203209

204210
def visit_Correlation(self, op, *, left, right, how, where):
205211
if how == "pop":

ibis/backends/sql/compilers/datafusion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class DataFusionCompiler(SQLGlotCompiler):
5151
)
5252

5353
SIMPLE_OPS = {
54+
ops.ApproxQuantile: "approx_percentile_cont",
5455
ops.ApproxMedian: "approx_median",
5556
ops.ArrayRemove: "array_remove_all",
5657
ops.BitAnd: "bit_and",

ibis/backends/sql/compilers/duckdb.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,17 @@ def visit_Quantile(self, op, *, arg, quantile, where):
532532
def visit_MultiQuantile(self, op, *, arg, quantile, where):
533533
return self.visit_Quantile(op, arg=arg, quantile=quantile, where=where)
534534

535+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate (multi-)quantile to ``approx_quantile``.

    Non-floating inputs are cast to float64 before aggregating.
    """
    # DuckDB casts the result back to the input type, so an integer input
    # always yields an integral result. Casting the *result* would only
    # give an integral float; casting the *input* makes integer columns
    # behave exactly like float columns.
    value = arg if op.arg.dtype.is_floating() else self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)


visit_ApproxMultiQuantile = visit_ApproxQuantile
545+
535546
def visit_HexDigest(self, op, *, arg, how):
536547
if how in ("md5", "sha256"):
537548
return getattr(self.f, how)(arg)

ibis/backends/sql/compilers/exasol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
223223
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
224224
)
225225

226+
visit_ApproxQuantile = visit_Quantile
227+
226228
def visit_TimestampTruncate(self, op, *, arg, unit):
227229
short_name = unit.short
228230
unit_mapping = {"W": "IW"}

ibis/backends/sql/compilers/mssql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,14 @@ def visit_CountDistinct(self, op, *, arg, where):
219219
arg = self.if_(where, arg, NULL)
220220
return self.f.count(sge.Distinct(expressions=[arg]))
221221

222+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile to ``APPROX_PERCENTILE_CONT``.

    Emits ``approx_percentile_cont(quantile) WITHIN GROUP (ORDER BY arg)``
    with NULLs ordered first. When a filter is given, values from rows
    that fail it are replaced by NULL.
    """
    target = arg if where is None else self.if_(where, arg, NULL)
    percentile = self.f.approx_percentile_cont(quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=target, nulls_first=True)])
    return sge.WithinGroup(this=percentile, expression=ordering)
229+
222230
def visit_DayOfWeekIndex(self, op, *, arg):
223231
return self.f.datepart(self.v.weekday, arg) - 1
224232

ibis/backends/sql/compilers/oracle.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):
308308
)
309309
return expr
310310

311+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile to ``APPROX_PERCENTILE``.

    Emits ``approx_percentile(quantile) WITHIN GROUP (ORDER BY arg)``.
    When a filter is given, ``arg`` is wrapped in ``self.if_(where, arg)``
    so that only rows passing the filter contribute their value.
    """
    target = arg if where is None else self.if_(where, arg)
    percentile = self.f.approx_percentile(quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=target)])
    return sge.WithinGroup(this=percentile, expression=ordering)
319+
311320
def visit_CountDistinct(self, op, *, arg, where):
312321
if where is not None:
313322
arg = self.if_(where, arg)

ibis/backends/sql/compilers/postgres.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
238238
return expr
239239

240240
visit_MultiQuantile = visit_Quantile
241+
visit_ApproxQuantile = visit_Quantile
242+
visit_ApproxMultiQuantile = visit_Quantile
241243

242244
def visit_Correlation(self, op, *, left, right, how, where):
243245
if how == "sample":

ibis/backends/sql/compilers/pyspark.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):
290290

291291
visit_MultiQuantile = visit_Quantile
292292

293+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate (multi-)quantile to ``approx_percentile``.

    Non-floating inputs are cast to float64 first. When a filter is
    given, values from rows that fail it are replaced by NULL.
    """
    value = arg if op.arg.dtype.is_floating() else self.cast(arg, dt.float64)
    if where is not None:
        value = self.if_(where, value, NULL)
    return self.f.approx_percentile(value, quantile)


visit_ApproxMultiQuantile = visit_ApproxQuantile
301+
293302
def visit_Correlation(self, op, *, left, right, how, where):
294303
if (left_type := op.left.dtype).is_boolean():
295304
left = self.cast(left, dt.Int32(nullable=left_type.nullable))

ibis/backends/sql/compilers/risingwave.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ibis.expr.datatypes as dt
77
import ibis.expr.operations as ops
88
from ibis.backends.sql.compilers import PostgresCompiler
9-
from ibis.backends.sql.compilers.base import ALL_OPERATIONS
9+
from ibis.backends.sql.compilers.base import ALL_OPERATIONS, NULL
1010
from ibis.backends.sql.datatypes import RisingWaveType
1111
from ibis.backends.sql.dialects import RisingWave
1212

@@ -22,6 +22,8 @@ class RisingWaveCompiler(PostgresCompiler):
2222
ops.DateFromYMD,
2323
ops.Mode,
2424
ops.RandomUUID,
25+
ops.MultiQuantile,
26+
ops.ApproxMultiQuantile,
2527
*(
2628
op
2729
for op in ALL_OPERATIONS
@@ -65,6 +67,17 @@ def visit_Correlation(self, op, *, left, right, how, where):
6567
op, left=left, right=right, how=how, where=where
6668
)
6769

70+
def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a quantile to ``percentile_cont`` / ``percentile_disc``.

    Numeric inputs use ``percentile_cont``; every other input type uses
    ``percentile_disc``. When a filter is given, values from rows that
    fail it are replaced by NULL. ``ApproxQuantile`` reuses this
    implementation.
    """
    target = arg if where is None else self.if_(where, arg, NULL)
    kind = "cont" if op.arg.dtype.is_numeric() else "disc"
    percentile = self.f[f"percentile_{kind}"](quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=target)])
    return sge.WithinGroup(this=percentile, expression=ordering)


visit_ApproxQuantile = visit_Quantile
80+
6881
def visit_TimestampTruncate(self, op, *, arg, unit):
6982
unit_mapping = {
7083
"Y": "year",

ibis/backends/sql/compilers/snowflake.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,12 @@ def visit_Quantile(self, op, *, arg, quantile, where):
608608
quantile = self.f.percentile_cont(quantile)
609609
return sge.WithinGroup(this=quantile, expression=order_by)
610610

611+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile to ``approx_percentile``.

    When a filter is given, values from rows that fail it are replaced
    by NULL before aggregating.
    """
    target = arg if where is None else self.if_(where, arg, NULL)
    return self.f.approx_percentile(target, quantile)
616+
611617
def visit_CountStar(self, op, *, arg, where):
612618
if where is None:
613619
return super().visit_CountStar(op, arg=arg, where=where)

ibis/backends/sql/compilers/trino.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,17 @@ def visit_Correlation(self, op, *, left, right, how, where):
133133

134134
return self.agg.corr(left, right, where=where)
135135

136+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate (multi-)quantile to ``approx_quantile``.

    Non-floating inputs are cast to float64 before aggregating.
    """
    # Trino casts the result back to the input type, so an integer input
    # always yields an integral result. Casting the *result* would only
    # give an integral float; casting the *input* makes integer columns
    # behave exactly like float columns.
    value = arg if op.arg.dtype.is_floating() else self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)


visit_ApproxMultiQuantile = visit_ApproxQuantile
146+
136147
def visit_BitXor(self, op, *, arg, where):
137148
a, b = map(sg.to_identifier, "ab")
138149
input_fn = combine_fn = sge.Lambda(

0 commit comments

Comments
 (0)