Skip to content

Commit c2e5330

Browse files
authored
feat: add lead and lag functions (#25)
1 parent cd41f21 commit c2e5330

File tree

4 files changed

+102
-22
lines changed

4 files changed

+102
-22
lines changed

subframe/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,10 @@ def literal(value: Any, type: str = None) -> Value:
125125
else:
126126
raise Exception(f"Unknown literal type - {type}")
127127

128-
print(literal)
129-
130128
return Value(
131129
expression=stalg.Expression(literal=literal),
132130
data_type=infer_literal_type(literal),
131+
name=str(value),
133132
)
134133

135134

subframe/utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,18 @@ def infer_expression_type(
119119
return infer_literal_type(expression.literal)
120120
elif rex_type == "scalar_function":
121121
return expression.scalar_function.output_type
122-
# WindowFunction window_function = 5;
123-
# IfThen if_then = 6;
124-
# SwitchExpression switch_expression = 7;
125-
# SingularOrList singular_or_list = 8;
126-
# MultiOrList multi_or_list = 9;
127-
# Cast cast = 11;
122+
elif rex_type == "window_function":
123+
return expression.window_function.output_type
124+
elif rex_type == "if_then":
125+
return infer_expression_type(expression.if_then.ifs[0].then)
126+
elif rex_type == "switch_expression":
127+
return infer_expression_type(expression.switch_expression.ifs[0].then)
128+
elif rex_type == "cast":
129+
return expression.cast.type
130+
elif rex_type == "singular_or_list" or rex_type == "multi_or_list":
131+
return Type(
132+
bool=Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE)
133+
)
128134
# Subquery subquery = 12;
129135
# Nested nested = 13;
130136
else:

subframe/value.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,6 @@
44
# from .table import Table
55

66

7-
def substrait_type_from_substrait_str(data_type: str) -> stt.Type:
8-
data_type = data_type.replace("?", "") # TODO
9-
if data_type == "i64":
10-
return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE))
11-
elif data_type == "fp64":
12-
return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE))
13-
elif data_type == "boolean":
14-
return stt.Type(
15-
bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)
16-
)
17-
else:
18-
raise Exception(f"Unknown data type {data_type}")
19-
20-
217
class Value:
228
def __init__(
239
self,
@@ -205,6 +191,66 @@ def count(self):
205191
col_name="Count",
206192
)
207193

194+
def _apply_window_function(
195+
self, additional_arguments: list["Value"], url: str, func: str, col_name: str
196+
):
197+
from subframe import registry
198+
199+
(func_entry, rtn) = registry.lookup_function(
200+
url,
201+
function_name=func,
202+
signature=[
203+
self.data_type,
204+
*[a.data_type for a in additional_arguments],
205+
],
206+
)
207+
208+
output_type = rtn
209+
210+
expression = stalg.Expression(
211+
window_function=stalg.Expression.WindowFunction(
212+
function_reference=func_entry.anchor,
213+
arguments=[
214+
stalg.FunctionArgument(value=self.expression),
215+
*[
216+
stalg.FunctionArgument(value=a.expression)
217+
for a in additional_arguments
218+
],
219+
],
220+
options=[],
221+
phase=stalg.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT,
222+
sorts=[],
223+
partitions=[],
224+
)
225+
)
226+
227+
return Value(
228+
expression=expression,
229+
data_type=output_type,
230+
name=f"{col_name}({self._name}, {' ,'.join([a._name for a in additional_arguments])})",
231+
extensions={func_entry.uri: {str(func_entry): func_entry.anchor}},
232+
)
233+
234+
def lead(self, offset): # TODO default
235+
from subframe import literal
236+
237+
if type(offset) == int:
238+
offset = literal(offset, type="i32")
239+
240+
return self._apply_window_function(
241+
[offset], "functions_arithmetic.yaml", "lead", "Lead"
242+
)
243+
244+
def lag(self, offset): # TODO default
245+
from subframe import literal
246+
247+
if type(offset) == int:
248+
offset = literal(offset, type="i32")
249+
250+
return self._apply_window_function(
251+
[offset], "functions_arithmetic.yaml", "lag", "Lag"
252+
)
253+
208254

209255
class AggregateValue:
210256
def __init__(

tests/test_execution.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,32 @@ def transform(module):
569569
sf_expr = transform(subframe)
570570

571571
run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr)
572+
573+
574+
@pytest.mark.parametrize(
575+
"consumer",
576+
[
577+
pytest.param(
578+
"acero_consumer",
579+
marks=[pytest.mark.xfail(Exception, reason="Unimplemented")],
580+
),
581+
"datafusion_consumer",
582+
pytest.param(
583+
"duckdb_consumer",
584+
marks=[pytest.mark.xfail(Exception, reason="Unimplemented")],
585+
),
586+
],
587+
)
588+
def test_lead_lag(consumer, request):
589+
590+
def transform(module):
591+
t1 = _orders(module)
592+
593+
return t1.select(
594+
t1["order_total"].lead(offset=2), t1["order_total"].lag(offset=2)
595+
)
596+
597+
ibis_expr = transform(ibis)
598+
sf_expr = transform(subframe)
599+
600+
run_parity_test(request.getfixturevalue(consumer), ibis_expr, sf_expr)

0 commit comments

Comments
 (0)