Skip to content
This repository has been archived by the owner on Mar 29, 2023. It is now read-only.

Commit

Permalink
fix: ensure that ScalarParameter names are used instead of Alias names (
Browse files Browse the repository at this point in the history
#135)

* fix: ensure that ScalarParameter names are used instead of Alias names

* fix: handle new properties required of UDFs

* test: fix scalar-parameterized tests
  • Loading branch information
cpcloud authored May 25, 2022
1 parent 71a01b9 commit bfe539a
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 45 deletions.
26 changes: 25 additions & 1 deletion ibis_bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
except ImportError:
pass

try:
from ibis.expr.operations import Alias
except ImportError:
# Allow older versions of ibis to work with ScalarParameters as well as
# versions >= 3.0.0
Alias = None


__version__: str = ibis_bigquery_version.__version__

Expand Down Expand Up @@ -222,7 +229,24 @@ def _execute(self, stmt, results=True, query_parameters=None):

def raw_sql(self, query: str, results=False, params=None):
query_parameters = [
bigquery_param(param, value) for param, value in (params or {}).items()
bigquery_param(
# unwrap Alias instances
#
# Without unwrapping we try to execute compiled code that uses
# the ScalarParameter's raw name (e.g., @param_1) and not the
# alias's name which will fail. By unwrapping, we always use
# the raw name.
#
# This workaround is backwards compatible and doesn't require
# changes to ibis.
(
param
if Alias is None or not isinstance(param.op(), Alias)
else param.op().arg
),
value,
)
for param, value in (params or {}).items()
]
return self._execute(query, results=results, query_parameters=query_parameters)

Expand Down
27 changes: 12 additions & 15 deletions ibis_bigquery/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,18 @@ def wrapper(f):
signature = inspect.signature(f)
parameter_names = signature.parameters.keys()

udf_node_fields = collections.OrderedDict(
[
(name, Arg(rlz.value(type)))
for name, type in zip(parameter_names, input_type)
]
+ [
(
"output_type",
lambda self, output_type=output_type: rlz.shape_like(
self.args, dtype=output_type
),
),
("__slots__", ("js",)),
]
)
udf_node_fields = {
name: Arg(rlz.value(type))
for name, type in zip(parameter_names, input_type)
}

try:
udf_node_fields["output_type"] = rlz.shape_like("args", dtype=output_type)
except TypeError:
udf_node_fields["output_dtype"] = property(lambda _: output_type)
udf_node_fields["output_shape"] = rlz.shape_like("args")

udf_node_fields["__slots__"] = ("js",)

udf_node = create_udf_node(f.__name__, udf_node_fields)

Expand Down
85 changes: 60 additions & 25 deletions tests/system/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import datetime
import decimal
import re

import ibis
import ibis.expr.datatypes as dt
Expand All @@ -11,6 +12,7 @@
import pandas.testing as tm
import pytest
import pytz
from pytest import param

import ibis_bigquery
from ibis_bigquery.client import bigquery_param
Expand All @@ -19,6 +21,13 @@
IBIS_1_4_VERSION = packaging.version.Version("1.4.0")
IBIS_3_0_VERSION = packaging.version.Version("3.0.0")

older_than_3 = pytest.mark.xfail(
IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3"
)
at_least_3 = pytest.mark.xfail(
IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3"
)


def test_table(alltypes):
assert isinstance(alltypes, ir.TableExpr)
Expand Down Expand Up @@ -204,7 +213,43 @@ def test_different_partition_col_name(monkeypatch, client):
assert col in parted_alltypes.columns


def test_subquery_scalar_params(alltypes, project_id, dataset_id):
def scalar_params_ibis3(project_id, dataset_id):
return f"""\
SELECT count\\(`foo`\\) AS `count`
FROM \\(
SELECT `string_col`, sum\\(`float_col`\\) AS `foo`
FROM \\(
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`
\\) t1
WHERE `timestamp_col` < @param_\\d+
GROUP BY 1
\\) t0"""


def scalar_params_not_ibis3(project_id, dataset_id):
return f"""\
SELECT count\\(`foo`\\) AS `count`
FROM \\(
SELECT `string_col`, sum\\(`float_col`\\) AS `foo`
FROM \\(
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`
WHERE `timestamp_col` < @my_param
\\) t1
GROUP BY 1
\\) t0"""


@pytest.mark.parametrize(
"expected_fn",
[
param(scalar_params_ibis3, marks=[older_than_3], id="ibis3"),
param(scalar_params_not_ibis3, marks=[at_least_3], id="not_ibis3"),
],
)
def test_subquery_scalar_params(alltypes, project_id, dataset_id, expected_fn):
expected = expected_fn(project_id, dataset_id)
t = alltypes
param = ibis.param("timestamp").name("my_param")
expr = (
Expand All @@ -216,20 +261,7 @@ def test_subquery_scalar_params(alltypes, project_id, dataset_id):
.foo.count()
)
result = expr.compile(params={param: "20140101"})
expected = """\
SELECT count(`foo`) AS `count`
FROM (
SELECT `string_col`, sum(`float_col`) AS `foo`
FROM (
SELECT `float_col`, `timestamp_col`, `int_col`, `string_col`
FROM `{}.{}.functional_alltypes`
WHERE `timestamp_col` < @my_param
) t1
GROUP BY 1
) t0""".format(
project_id, dataset_id
)
assert result == expected
assert re.match(expected, result) is not None


def test_scalar_param_string(alltypes, df):
Expand Down Expand Up @@ -457,18 +489,21 @@ def test_raw_sql(client):
assert client.raw_sql("SELECT 1").fetchall() == [(1,)]


def test_scalar_param_scope(alltypes, project_id, dataset_id):
@pytest.mark.parametrize(
"pattern",
[
param(r"@param_\d+", marks=[older_than_3], id="ibis3"),
param("@param", marks=[at_least_3], id="not_ibis3"),
],
)
def test_scalar_param_scope(alltypes, project_id, dataset_id, pattern):
t = alltypes
param = ibis.param("timestamp")
mut = t.mutate(param=param).compile(params={param: "2017-01-01"})
assert (
mut
== """\
SELECT *, @param AS `param`
FROM `{}.{}.functional_alltypes`""".format(
project_id, dataset_id
)
)
result = t.mutate(param=param).compile(params={param: "2017-01-01"})
expected = f"""\
SELECT \\*, {pattern} AS `param`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`"""
assert re.match(expected, result) is not None


def test_parted_column_rename(parted_alltypes):
Expand Down
26 changes: 22 additions & 4 deletions tests/system/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
import re

import ibis
import ibis.expr.datatypes as dt
import packaging.version
import pytest
from pytest import param

pytestmark = pytest.mark.bigquery

IBIS_VERSION = packaging.version.Version(ibis.__version__)
IBIS_1_VERSION = packaging.version.Version("1.4.0")
IBIS_3_0_VERSION = packaging.version.Version("3.0.0")

older_than_3 = pytest.mark.xfail(
IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3"
)
at_least_3 = pytest.mark.xfail(
IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3"
)

def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id):

@pytest.mark.parametrize(
"pattern",
[
param(r"@param_\d+", marks=[older_than_3], id="ibis3"),
param("@param", marks=[at_least_3], id="not_ibis3"),
],
)
def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id, pattern):
date_string = "2009-03-01"
param = ibis.param(dt.timestamp).name("param_0")
expr = alltypes.mutate(param=param)
params = {param: date_string}
result = expr.compile(params=params)
expected = f"""\
SELECT *, @param AS `param`
FROM `{project_id}.{dataset_id}.functional_alltypes`"""
assert result == expected
SELECT \\*, {pattern} AS `param`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`"""
assert re.match(expected, result) is not None


@pytest.mark.parametrize(
Expand Down

0 comments on commit bfe539a

Please sign in to comment.