Skip to content

Commit

Permalink
refactor: port the internals to use koerce
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 10, 2024
1 parent 77af14b commit df7ab80
Show file tree
Hide file tree
Showing 92 changed files with 1,769 additions and 9,441 deletions.
5 changes: 4 additions & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,10 @@ def _register_udf(self, udf_node: ops.ScalarUDF):
name = type(udf_node).__name__
type_mapper = self.compiler.type_mapper
input_types = [
type_mapper.to_string(param.annotation.pattern.dtype)
# TODO(kszucs): the data type of the input parameters should be
# retrieved differently rather than relying on the validator
# in the signature
type_mapper.to_string(param.pattern.func.dtype)
for param in udf_node.__signature__.parameters.values()
]
output_type = type_mapper.to_string(udf_node.dtype)
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import Replace
from public import public

import ibis.common.exceptions as com
import ibis.common.patterns as pats
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
Expand Down Expand Up @@ -239,15 +239,15 @@ class SQLGlotCompiler(abc.ABC):
agg = AggGen()
"""A generator for handling aggregate functions"""

rewrites: tuple[type[pats.Replace], ...] = (
rewrites: tuple[type[Replace], ...] = (
empty_in_values_right_side,
add_order_by_to_empty_ranking_window_functions,
one_to_zero_index,
add_one_to_nth_value_input,
)
"""A sequence of rewrites to apply to the expression tree before SQL-specific transforms."""

post_rewrites: tuple[type[pats.Replace], ...] = ()
post_rewrites: tuple[type[Replace], ...] = ()
"""A sequence of rewrites to apply to the expression tree after SQL-specific transforms."""

no_limit_value: sge.Null | None = None
Expand Down Expand Up @@ -290,7 +290,7 @@ class SQLGlotCompiler(abc.ABC):
UNSUPPORTED_OPS: tuple[type[ops.Node], ...] = ()
"""Tuple of operations the backend doesn't support."""

LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = {
LOWERED_OPS: dict[type[ops.Node], Replace | None] = {
ops.Bucket: lower_bucket,
ops.Capitalize: lower_capitalize,
ops.Sample: lower_sample,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import var

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
Expand All @@ -26,7 +27,6 @@
replace,
split_select_distinct_with_order_by,
)
from ibis.common.deferred import var

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import replace

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
Expand All @@ -18,7 +19,6 @@
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
)
from ibis.common.patterns import replace
from ibis.expr.rewrites import p


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import sqlglot as sg
import sqlglot.expressions as sge
from koerce import replace

import ibis
import ibis.common.exceptions as com
Expand All @@ -21,7 +22,6 @@
p,
split_select_distinct_with_order_by,
)
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from typing import TYPE_CHECKING, Any

import toolz
from koerce import Is, Object, Pattern, attribute, replace, var
from public import public

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import attribute
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import var
from ibis.common.graph import Graph
from ibis.common.patterns import InstanceOf, Object, Pattern, replace
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.rewrites import d, p, replace_parameter
from ibis.expr.schema import Schema
Expand Down Expand Up @@ -330,7 +328,7 @@ def extract_ctes(node: ops.Relation) -> set[ops.Relation]:
cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample)
dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar)

g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
g = Graph.from_bfs(node, filter=~Is(dont_count))
result = set()
for op, dependents in g.invert().items():
if isinstance(op, ops.View) or (
Expand Down Expand Up @@ -403,7 +401,7 @@ def sqlize(
if ctes:

def apply_ctes(node, kwargs):
new = node.__recreate__(kwargs) if kwargs else node
new = node.__class__(**kwargs) if kwargs else node
return CTE(new) if node in ctes else new

result = result.replace(apply_ctes)
Expand Down Expand Up @@ -454,7 +452,7 @@ def split_select_distinct_with_order_by(_):
return _


@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
@replace(p.WindowFunction(func=p.NTile(+y), order_by=()))
def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
return _.copy(order_by=(y,))
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError

np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PsycoPg2InternalError,
PyODBCProgrammingError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError
from ibis.util import gen_name

np = pytest.importorskip("numpy")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.grounds import ValidationError

np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
Expand Down
Loading

0 comments on commit df7ab80

Please sign in to comment.