Skip to content

Commit 7ad4e22

Browse files
authored
feat: add improved extension registry (#12)
* add improved extension registry * fix: update lock file
1 parent e757d2c commit 7ad4e22

28 files changed

+937
-7022
lines changed

fetch_extensions.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

pixi.lock

Lines changed: 0 additions & 828 deletions
This file was deleted.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.0.1"
44
requires-python = ">=3.9"
55
dependencies = [
66
"ibis-framework[duckdb]",
7+
"pyparsing",
78
"ibis-substrait",
89
"pyarrow",
910
"pytest",

requirements.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ pyarrow-hotfix==0.6
4747
# via ibis-framework
4848
pygments==2.18.0
4949
# via rich
50+
pyparsing==3.1.4
51+
# via subframe (pyproject.toml)
5052
pytest==8.3.3
5153
# via subframe (pyproject.toml)
5254
python-dateutil==2.9.0.post0

subframe/__init__.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,9 @@
55
from substrait.gen.proto import algebra_pb2 as stalg
66
from .table import Table
77
from .value import Value
8-
from .extensions.extension_registry import ExtensionRegistry
8+
from .extension_registry import FunctionRegistry
99

10-
registry = ExtensionRegistry(
11-
[
12-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_approx.yaml",
13-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_decimal_output.yaml",
14-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml",
15-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml",
16-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml",
17-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml",
18-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml",
19-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_datetime.yaml",
20-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_geometry.yaml",
21-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_logarithmic.yaml",
22-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_rounding.yaml",
23-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_set.yaml",
24-
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml",
25-
"https://github.com/substrait-io/substrait/blob/main/extensions/type_variations.yaml",
26-
]
27-
)
10+
registry = FunctionRegistry()
2811

2912

3013
def substrait_type_from_string(type: str):

subframe/derivation_expression.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
from .utils import to_substrait_type
2+
from typing import Optional
3+
from substrait.gen.proto.type_expressions_pb2 import DerivationExpression
4+
from substrait.gen.proto.type_pb2 import Type
5+
from pyparsing import (
6+
Forward,
7+
Literal,
8+
ParseResults,
9+
Word,
10+
ZeroOrMore,
11+
identchars,
12+
infix_notation,
13+
nums,
14+
oneOf,
15+
opAssoc,
16+
)
17+
18+
# Forward-declared placeholder for the expression grammar. The type and
# function rules below refer to it recursively; it is bound to the real
# operator-precedence grammar later in this module via ``expr << ...``.
expr = Forward()
19+
20+
21+
def parse_dtype(tokens: ParseResults):
    """Parse action that turns a matched type name into a ``DerivationExpression``.

    Reads the named results ``dtype`` (and, for decimals, ``scale`` and
    ``precision``) from ``tokens``.

    Raises:
        Exception: if the matched ``dtype`` is not one of the handled names.
    """
    named = tokens.as_dict()
    type_name = named["dtype"]

    # Only the decimal name is compared case-insensitively (the grammar
    # accepts both "DECIMAL" and "decimal").
    if type_name.lower() == "decimal":
        decimal_expr = DerivationExpression.ExpressionDecimal(
            scale=named["scale"], precision=named["precision"]
        )
        return DerivationExpression(decimal=decimal_expr)

    # Type name -> (DerivationExpression field, protobuf type class).
    simple_types = {
        "boolean": ("bool", Type.Boolean),
        "i8": ("i8", Type.I8),
        "i16": ("i16", Type.I16),
        "i32": ("i32", Type.I32),
        "i64": ("i64", Type.I64),
        "fp32": ("fp32", Type.FP32),
        "fp64": ("fp64", Type.FP64),
    }
    if type_name not in simple_types:
        raise Exception(f"Unknown dtype - {type_name}")

    field_name, type_cls = simple_types[type_name]
    return DerivationExpression(**{field_name: type_cls()})
46+
47+
48+
# Parameterised decimal type: ``decimal<expr, expr>`` (upper- or lower-case
# keyword), with the angle brackets and comma dropped from the results.
# NOTE(review): the first argument is labelled "scale" and the second
# "precision"; Substrait usually writes DECIMAL<precision, scale> -- confirm
# the labels are intended (evaluate_expression emits them in this same order).
_decimal_dtype = (
    oneOf("DECIMAL decimal")("dtype")
    + Literal("<").suppress()
    + expr("scale")
    + Literal(",").suppress()
    + expr("precision")
    + Literal(">").suppress()
)

# Any supported type name; every alternative stores its keyword under the
# "dtype" results name, which parse_dtype then converts.
dtype = (
    Literal("i8")("dtype")
    | Literal("i16")("dtype")
    | Literal("i32")("dtype")
    | Literal("i64")("dtype")
    | Literal("fp32")("dtype")
    | Literal("fp64")("dtype")
    | Literal("boolean")("dtype")
    | _decimal_dtype
).set_parse_action(parse_dtype)
63+
64+
# Names of the two-argument functions accepted by the ``binary_fn`` rule.
supported_functions = ["max", "min"]
65+
66+
67+
def parse_binary_fn(tokens: ParseResults):
    """Parse action for ``min(a, b)`` / ``max(a, b)`` calls.

    ``tokens`` holds the function name followed by its two already-parsed
    argument expressions.

    Raises:
        Exception: if the function name is neither ``min`` nor ``max``.
    """
    op_types = DerivationExpression.BinaryOp.BinaryOpType
    fn_name = tokens[0]

    if fn_name == "max":
        op_type = op_types.BINARY_OP_TYPE_MAX
    elif fn_name == "min":
        op_type = op_types.BINARY_OP_TYPE_MIN
    else:
        raise Exception(f"Unknown operation {fn_name}")

    call = DerivationExpression.BinaryOp(
        op_type=op_type, arg1=tokens[1], arg2=tokens[2]
    )
    return DerivationExpression(binary_op=call)
80+
81+
82+
# Function-call syntax ``name(expr, expr)`` for the supported functions; the
# parentheses and comma are suppressed so only the name and the two argument
# expressions reach parse_binary_fn.
_binary_fn_call = (
    oneOf(supported_functions)("fn")
    + Literal("(").suppress()
    + expr
    + Literal(",").suppress()
    + expr
    + Literal(")").suppress()
)
binary_fn = _binary_fn_call.set_parse_action(parse_binary_fn)
90+
91+
def _parse_integer_literal(toks):
    """Parse action converting a run of digits into an integer-literal expression."""
    return DerivationExpression(integer_literal=int(toks[0]))


# One or more decimal digits.
integer_literal = Word(nums).set_parse_action(_parse_integer_literal)
94+
95+
96+
def parse_parameter(pr: ParseResults):
    """Parse action wrapping a matched identifier as a named integer parameter."""
    parameter_name = pr[0]
    return DerivationExpression(integer_parameter_name=parameter_name)
98+
99+
100+
# A parameter reference: a word built from identifier characters and digits,
# converted by parse_parameter into an ``integer_parameter_name`` expression.
parameter = Word(identchars + nums).set_parse_action(parse_parameter)
101+
102+
# Atomic operands of an expression. ``|`` tries the alternatives in the order
# written, so the catch-all ``parameter`` rule comes last.
operand = integer_literal | binary_fn | dtype | parameter
103+
104+
105+
def parse_binary_op(pr):
    """Parse action folding an infix operator chain into nested ``BinaryOp`` nodes.

    ``pr[0]`` is the flat group produced by ``infix_notation`` for one
    precedence level: ``[operand, op, operand, op, operand, ...]``. The chain
    is folded left to right, so ``a - b - c`` becomes ``(a - b) - c``.

    Args:
        pr: parse results whose first element is the operand/operator group.

    Returns:
        The outermost ``DerivationExpression``, or ``None`` when the group
        contains no operator (pyparsing then leaves the tokens unchanged).

    Raises:
        Exception: if an operator symbol is not recognised.
    """
    op_types = DerivationExpression.BinaryOp.BinaryOpType
    # The grammar accepts "/" alongside "*", so it must be mapped here too;
    # previously a division was parsed and then rejected as unknown.
    op_type_by_symbol = {
        "*": op_types.BINARY_OP_TYPE_MULTIPLY,
        "/": op_types.BINARY_OP_TYPE_DIVIDE,
        "+": op_types.BINARY_OP_TYPE_PLUS,
        "-": op_types.BINARY_OP_TYPE_MINUS,
        ">": op_types.BINARY_OP_TYPE_GREATER_THAN,
        "<": op_types.BINARY_OP_TYPE_LESS_THAN,
    }

    tokens = pr[0]
    folded = None
    # Operators sit at the odd indices, with their operands on either side.
    for i in range(1, len(tokens), 2):
        symbol = tokens[i]
        if symbol not in op_type_by_symbol:
            raise Exception(f"Unknown operation {symbol}")

        # Compare against None explicitly rather than relying on the
        # truthiness of a protobuf message.
        left = tokens[i - 1] if folded is None else folded
        folded = DerivationExpression(
            binary_op=DerivationExpression.BinaryOp(
                op_type=op_type_by_symbol[symbol],
                arg1=left,
                arg2=tokens[i + 1],
            )
        )

    return folded
135+
136+
137+
def parse_ternary(pr):
    """Parse action building an ``IfElse`` node from ``cond ? a : b``.

    ``pr[0]`` is the group of three parsed sub-expressions in source order:
    condition, value when true, value when false (the ``?`` and ``:`` tokens
    are suppressed by the grammar).
    """
    parts = pr[0]
    branch = DerivationExpression.IfElse(
        if_condition=parts[0],
        if_return=parts[1],
        else_return=parts[2],
    )
    return DerivationExpression(if_else=branch)
144+
145+
146+
# The ``cond ? a : b`` operator is described to infix_notation as a pair of
# (suppressed) delimiter tokens.
_ternary_operator = (Literal("?").suppress(), Literal(":").suppress())

# Operator levels in the order passed to infix_notation, which treats earlier
# entries as binding more tightly.
_operator_table = [
    (oneOf("* /")("binary_op"), 2, opAssoc.LEFT, parse_binary_op),
    (oneOf("+ -")("binary_op"), 2, opAssoc.LEFT, parse_binary_op),
    (oneOf("> <")("binary_op"), 2, opAssoc.LEFT, parse_binary_op),
    (_ternary_operator, 3, opAssoc.RIGHT, parse_ternary),
]

# Bind the forward-declared ``expr`` to the full expression grammar.
expr << infix_notation(operand, _operator_table)
160+
161+
162+
def parse_assignment(toks):
    """Parse action building a ``ReturnProgram.Assignment`` from ``name = expr``.

    Reads the ``name`` and ``expression`` named results from ``toks``.
    """
    named = toks.as_dict()
    target = named["name"]
    value = named["expression"]
    return DerivationExpression.ReturnProgram.Assignment(
        name=target, expression=value
    )
167+
168+
169+
# ``name = expr``: the target is stored as "name", the right-hand side as
# "expression"; the "=" sign itself is dropped from the results.
_assignment_target = Word(identchars + nums)("name")
_assignment_value = expr("expression")
assignment = (
    _assignment_target + Literal("=").suppress() + _assignment_value
).set_parse_action(parse_assignment)
172+
173+
174+
def parse_return_program(toks):
    """Parse action assembling a ``ReturnProgram`` expression.

    Reads the ``assignments`` and ``final_expression`` named results from
    ``toks`` and wraps them in a ``DerivationExpression``.
    """
    named = toks.as_dict()
    program = DerivationExpression.ReturnProgram(
        assignments=named["assignments"],
        final_expression=named["final_expression"],
    )
    return DerivationExpression(return_program=program)
181+
182+
183+
# A whole derivation program: any number of assignments followed by the
# expression whose value is the result.
_program_assignments = ZeroOrMore(assignment)("assignments")
_program_result = expr("final_expression")
return_program = (_program_assignments + _program_result).set_parse_action(
    parse_return_program
)
186+
187+
188+
def to_proto(txt: str):
    """Parse derivation-expression text and return the first parse result.

    The text is parsed with the ``return_program`` grammar, whose parse
    action produces a ``DerivationExpression``.
    """
    parsed = return_program.parseString(txt)
    return parsed[0]
190+
191+
192+
def evaluate_expression(de: DerivationExpression, values: Optional[dict] = None):
    """Recursively evaluate a ``DerivationExpression``.

    Args:
        de: the expression to evaluate.
        values: mapping from parameter name to value. Assignments in a
            ``return_program`` are written into this mapping, so a dict
            passed by the caller is modified in place. ``None`` is treated
            as an empty mapping.

    Returns:
        The integer literal or looked-up parameter value, the result of
        applying an arithmetic, min/max or comparison operator to the
        evaluated arguments, the selected ternary branch, or the value
        returned by ``to_substrait_type`` for type expressions.

    Raises:
        KeyError: if a referenced parameter name is not in ``values``.
        Exception: for an unknown expression kind or binary operator type.
    """
    # The default is None, but the branches below index and assign into the
    # mapping, so start from a fresh dict instead of failing on None.
    if values is None:
        values = {}

    kind = de.WhichOneof("kind")
    if kind == "return_program":
        program = de.return_program
        # Assignments run in order so later ones can use earlier names.
        for assign in program.assignments:
            values[assign.name] = evaluate_expression(assign.expression, values)
        return evaluate_expression(program.final_expression, values)
    elif kind == "integer_literal":
        return de.integer_literal
    elif kind == "integer_parameter_name":
        return values[de.integer_parameter_name]
    elif kind == "binary_op":
        binary_op = de.binary_op
        op_type = binary_op.op_type
        arg1_eval = evaluate_expression(binary_op.arg1, values)
        arg2_eval = evaluate_expression(binary_op.arg2, values)
        if op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_PLUS:
            return arg1_eval + arg2_eval
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MINUS:
            return arg1_eval - arg2_eval
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MULTIPLY:
            return arg1_eval * arg2_eval
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MIN:
            return min(arg1_eval, arg2_eval)
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MAX:
            return max(arg1_eval, arg2_eval)
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_GREATER_THAN:
            return arg1_eval > arg2_eval
        elif op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_LESS_THAN:
            return arg1_eval < arg2_eval
        else:
            raise Exception(f"Unknown binary op type - {binary_op.op_type}")
    elif kind == "if_else":
        if_else = de.if_else
        # Evaluate only the selected branch, so an error in the branch that
        # is not taken cannot abort the evaluation.
        if evaluate_expression(if_else.if_condition, values):
            return evaluate_expression(if_else.if_return, values)
        return evaluate_expression(if_else.else_return, values)
    elif kind == "decimal":
        decimal = de.decimal
        scale_eval = evaluate_expression(decimal.scale, values)
        precision_eval = evaluate_expression(decimal.precision, values)
        return to_substrait_type(f"decimal<{scale_eval},{precision_eval}>")
    elif kind in ("i8", "i16", "i32", "i64", "fp32", "fp64"):
        # For these kinds the oneof field name is passed straight through
        # as the type string.
        return to_substrait_type(kind)
    elif kind == "bool":
        return to_substrait_type("boolean")
    else:
        raise Exception(f"Unknown derivation expression type - {kind}")
244+
245+
246+
def evaluate(txt: str, values: Optional[dict] = None):
    """Parse ``txt`` as a derivation expression and evaluate it.

    Args:
        txt: derivation-expression source text.
        values: optional mapping of parameter names to values; when it is
            missing or empty, a fresh empty dict is used instead.

    Returns:
        Whatever ``evaluate_expression`` returns for the parsed program.
    """
    bindings = values if values else {}
    return evaluate_expression(to_proto(txt), bindings)

0 commit comments

Comments
 (0)