|
| 1 | +from .utils import to_substrait_type |
| 2 | +from typing import Optional |
| 3 | +from substrait.gen.proto.type_expressions_pb2 import DerivationExpression |
| 4 | +from substrait.gen.proto.type_pb2 import Type |
| 5 | +from pyparsing import ( |
| 6 | + Forward, |
| 7 | + Literal, |
| 8 | + ParseResults, |
| 9 | + Word, |
| 10 | + ZeroOrMore, |
| 11 | + identchars, |
| 12 | + infix_notation, |
| 13 | + nums, |
| 14 | + oneOf, |
| 15 | + opAssoc, |
| 16 | +) |
| 17 | + |
| 18 | +expr = Forward() |
| 19 | + |
| 20 | + |
| 21 | +def parse_dtype(tokens: ParseResults): |
| 22 | + tokens_dict = tokens.as_dict() |
| 23 | + dtype = tokens_dict["dtype"].lower() |
| 24 | + if dtype == "decimal": |
| 25 | + return DerivationExpression( |
| 26 | + decimal=DerivationExpression.ExpressionDecimal( |
| 27 | + scale=tokens_dict["scale"], precision=tokens_dict["precision"] |
| 28 | + ) |
| 29 | + ) |
| 30 | + elif tokens_dict["dtype"] == "boolean": |
| 31 | + return DerivationExpression(bool=Type.Boolean()) |
| 32 | + elif tokens_dict["dtype"] == "i8": |
| 33 | + return DerivationExpression(i8=Type.I8()) |
| 34 | + elif tokens_dict["dtype"] == "i16": |
| 35 | + return DerivationExpression(i16=Type.I16()) |
| 36 | + elif tokens_dict["dtype"] == "i32": |
| 37 | + return DerivationExpression(i32=Type.I32()) |
| 38 | + elif tokens_dict["dtype"] == "i64": |
| 39 | + return DerivationExpression(i64=Type.I64()) |
| 40 | + elif tokens_dict["dtype"] == "fp32": |
| 41 | + return DerivationExpression(fp32=Type.FP32()) |
| 42 | + elif tokens_dict["dtype"] == "fp64": |
| 43 | + return DerivationExpression(fp64=Type.FP64()) |
| 44 | + else: |
| 45 | + raise Exception(f"Unknown dtype - {tokens_dict['dtype']}") |
| 46 | + |
| 47 | + |
| 48 | +dtype = ( |
| 49 | + Literal("i8")("dtype") |
| 50 | + | Literal("i16")("dtype") |
| 51 | + | Literal("i32")("dtype") |
| 52 | + | Literal("i64")("dtype") |
| 53 | + | Literal("fp32")("dtype") |
| 54 | + | Literal("fp64")("dtype") |
| 55 | + | Literal("boolean")("dtype") |
| 56 | + | oneOf("DECIMAL decimal")("dtype") |
| 57 | + + Literal("<").suppress() |
| 58 | + + expr("scale") |
| 59 | + + Literal(",").suppress() |
| 60 | + + expr("precision") |
| 61 | + + Literal(">").suppress() |
| 62 | +).set_parse_action(parse_dtype) |
| 63 | + |
| 64 | +supported_functions = ["max", "min"] |
| 65 | + |
| 66 | + |
| 67 | +def parse_binary_fn(tokens: ParseResults): |
| 68 | + if tokens[0] == "min": |
| 69 | + op_type = DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MIN |
| 70 | + elif tokens[0] == "max": |
| 71 | + op_type = DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MAX |
| 72 | + else: |
| 73 | + raise Exception(f"Unknown operation {tokens[0]}") |
| 74 | + |
| 75 | + return DerivationExpression( |
| 76 | + binary_op=DerivationExpression.BinaryOp( |
| 77 | + op_type=op_type, arg1=tokens[1], arg2=tokens[2] |
| 78 | + ) |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +binary_fn = ( |
| 83 | + oneOf(supported_functions)("fn") |
| 84 | + + Literal("(").suppress() |
| 85 | + + expr |
| 86 | + + Literal(",").suppress() |
| 87 | + + expr |
| 88 | + + Literal(")").suppress() |
| 89 | +).set_parse_action(parse_binary_fn) |
| 90 | + |
| 91 | +integer_literal = Word(nums).set_parse_action( |
| 92 | + lambda toks: DerivationExpression(integer_literal=int(toks[0])) |
| 93 | +) |
| 94 | + |
| 95 | + |
| 96 | +def parse_parameter(pr: ParseResults): |
| 97 | + return DerivationExpression(integer_parameter_name=pr[0]) |
| 98 | + |
| 99 | + |
| 100 | +parameter = Word(identchars + nums).set_parse_action(parse_parameter) |
| 101 | + |
| 102 | +operand = integer_literal | binary_fn | dtype | parameter |
| 103 | + |
| 104 | + |
| 105 | +def parse_binary_op(pr): |
| 106 | + tokens = pr[0] |
| 107 | + prev_expression = None |
| 108 | + for i in range(1, len(tokens), 2): |
| 109 | + if tokens[i] == "*": |
| 110 | + op_type = DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MULTIPLY |
| 111 | + elif tokens[i] == "+": |
| 112 | + op_type = DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS |
| 113 | + elif tokens[i] == "-": |
| 114 | + op_type = DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_MINUS |
| 115 | + elif tokens[i] == ">": |
| 116 | + op_type = ( |
| 117 | + DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_GREATER_THAN |
| 118 | + ) |
| 119 | + elif tokens[i] == "<": |
| 120 | + op_type = ( |
| 121 | + DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_LESS_THAN |
| 122 | + ) |
| 123 | + else: |
| 124 | + raise Exception(f"Unknown operation {tokens[i]}") |
| 125 | + |
| 126 | + prev_expression = DerivationExpression( |
| 127 | + binary_op=DerivationExpression.BinaryOp( |
| 128 | + op_type=op_type, |
| 129 | + arg1=prev_expression if prev_expression else tokens[i - 1], |
| 130 | + arg2=tokens[i + 1], |
| 131 | + ) |
| 132 | + ) |
| 133 | + |
| 134 | + return prev_expression |
| 135 | + |
| 136 | + |
| 137 | +def parse_ternary(pr): |
| 138 | + tokens = pr[0] |
| 139 | + return DerivationExpression( |
| 140 | + if_else=DerivationExpression.IfElse( |
| 141 | + if_condition=tokens[0], if_return=tokens[1], else_return=tokens[2] |
| 142 | + ) |
| 143 | + ) |
| 144 | + |
| 145 | + |
| 146 | +expr << infix_notation( |
| 147 | + operand, |
| 148 | + [ |
| 149 | + (oneOf("* /")("binary_op"), 2, opAssoc.LEFT, parse_binary_op), |
| 150 | + (oneOf("+ -")("binary_op"), 2, opAssoc.LEFT, parse_binary_op), |
| 151 | + (oneOf("> <")("binary_op"), 2, opAssoc.LEFT, parse_binary_op), |
| 152 | + ( |
| 153 | + (Literal("?").suppress(), Literal(":").suppress()), |
| 154 | + 3, |
| 155 | + opAssoc.RIGHT, |
| 156 | + parse_ternary, |
| 157 | + ), |
| 158 | + ], |
| 159 | +) |
| 160 | + |
| 161 | + |
| 162 | +def parse_assignment(toks): |
| 163 | + tokens_dict = toks.as_dict() |
| 164 | + return DerivationExpression.ReturnProgram.Assignment( |
| 165 | + name=tokens_dict["name"], expression=tokens_dict["expression"] |
| 166 | + ) |
| 167 | + |
| 168 | + |
| 169 | +assignment = ( |
| 170 | + Word(identchars + nums)("name") + Literal("=").suppress() + expr("expression") |
| 171 | +).set_parse_action(parse_assignment) |
| 172 | + |
| 173 | + |
| 174 | +def parse_return_program(toks): |
| 175 | + return DerivationExpression( |
| 176 | + return_program=DerivationExpression.ReturnProgram( |
| 177 | + assignments=toks.as_dict()["assignments"], |
| 178 | + final_expression=toks.as_dict()["final_expression"], |
| 179 | + ) |
| 180 | + ) |
| 181 | + |
| 182 | + |
| 183 | +return_program = ( |
| 184 | + ZeroOrMore(assignment)("assignments") + expr("final_expression") |
| 185 | +).set_parse_action(parse_return_program) |
| 186 | + |
| 187 | + |
| 188 | +def to_proto(txt: str): |
| 189 | + return return_program.parseString(txt)[0] |
| 190 | + |
| 191 | + |
| 192 | +def evaluate_expression(de: DerivationExpression, values: Optional[dict] = None): |
| 193 | + kind = de.WhichOneof("kind") |
| 194 | + if kind == "return_program": |
| 195 | + for assign in de.return_program.assignments: |
| 196 | + values[assign.name] = evaluate_expression(assign.expression, values) |
| 197 | + return evaluate_expression(de.return_program.final_expression, values) |
| 198 | + elif kind == "integer_literal": |
| 199 | + return de.integer_literal |
| 200 | + elif kind == "integer_parameter_name": |
| 201 | + return values[de.integer_parameter_name] |
| 202 | + elif kind == "binary_op": |
| 203 | + binary_op = de.binary_op |
| 204 | + arg1_eval = evaluate_expression(binary_op.arg1, values) |
| 205 | + arg2_eval = evaluate_expression(binary_op.arg2, values) |
| 206 | + if binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_PLUS: |
| 207 | + return arg1_eval + arg2_eval |
| 208 | + elif binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MINUS: |
| 209 | + return arg1_eval - arg2_eval |
| 210 | + elif binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MULTIPLY: |
| 211 | + return arg1_eval * arg2_eval |
| 212 | + elif binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MIN: |
| 213 | + return min(arg1_eval, arg2_eval) |
| 214 | + elif binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_MAX: |
| 215 | + return max(arg1_eval, arg2_eval) |
| 216 | + elif ( |
| 217 | + binary_op.op_type |
| 218 | + == DerivationExpression.BinaryOp.BINARY_OP_TYPE_GREATER_THAN |
| 219 | + ): |
| 220 | + return arg1_eval > arg2_eval |
| 221 | + elif ( |
| 222 | + binary_op.op_type == DerivationExpression.BinaryOp.BINARY_OP_TYPE_LESS_THAN |
| 223 | + ): |
| 224 | + return arg1_eval < arg2_eval |
| 225 | + else: |
| 226 | + raise Exception(f"Unknown binary op type - {binary_op.op_type}") |
| 227 | + elif kind == "if_else": |
| 228 | + if_else = de.if_else |
| 229 | + if_return_eval = evaluate_expression(if_else.if_return, values) |
| 230 | + if_condition_eval = evaluate_expression(if_else.if_condition, values) |
| 231 | + else_return_eval = evaluate_expression(if_else.else_return, values) |
| 232 | + return if_return_eval if if_condition_eval else else_return_eval |
| 233 | + elif kind == "decimal": |
| 234 | + decimal = de.decimal |
| 235 | + scale_eval = evaluate_expression(decimal.scale, values) |
| 236 | + precision_eval = evaluate_expression(decimal.precision, values) |
| 237 | + return to_substrait_type(f"decimal<{scale_eval},{precision_eval}>") |
| 238 | + elif kind in ("i8", "i16", "i32", "i64", "fp32", "fp64"): |
| 239 | + return to_substrait_type(kind) |
| 240 | + elif kind == "bool": |
| 241 | + return to_substrait_type("boolean") |
| 242 | + else: |
| 243 | + raise Exception(f"Unknown derivation expression type - {kind}") |
| 244 | + |
| 245 | + |
| 246 | +def evaluate(txt: str, values: Optional[dict] = None): |
| 247 | + if not values: |
| 248 | + values = {} |
| 249 | + return evaluate_expression(to_proto(txt), values) |
0 commit comments