From bfe539a7c60439f7a521e230736aab3961dbabcc Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 25 May 2022 09:42:00 -0700 Subject: [PATCH] fix: ensure that ScalarParameter names are used instead of Alias names (#135) * fix: ensure that ScalarParameter names are used instead of Alias names * fix: handle new properties required of UDFs * test: fix scalar-parameterized tests --- ibis_bigquery/__init__.py | 26 ++++++++++- ibis_bigquery/udf/__init__.py | 27 +++++------ tests/system/test_client.py | 85 ++++++++++++++++++++++++----------- tests/system/test_compiler.py | 26 +++++++++-- 4 files changed, 119 insertions(+), 45 deletions(-) diff --git a/ibis_bigquery/__init__.py b/ibis_bigquery/__init__.py index 8ed5703..53a6728 100644 --- a/ibis_bigquery/__init__.py +++ b/ibis_bigquery/__init__.py @@ -29,6 +29,13 @@ except ImportError: pass +try: + from ibis.expr.operations import Alias +except ImportError: + # Allow older versions of ibis to work with ScalarParameters as well as + # versions >= 3.0.0 + Alias = None + __version__: str = ibis_bigquery_version.__version__ @@ -222,7 +229,24 @@ def _execute(self, stmt, results=True, query_parameters=None): def raw_sql(self, query: str, results=False, params=None): query_parameters = [ - bigquery_param(param, value) for param, value in (params or {}).items() + bigquery_param( + # unwrap Alias instances + # + # Without unwrapping we try to execute compiled code that uses + # the ScalarParameter's raw name (e.g., @param_1) and not the + # alias's name which will fail. By unwrapping, we always use + # the raw name. + # + # This workaround is backwards compatible and doesn't require + # changes to ibis. + ( + param + if Alias is None or not isinstance(param.op(), Alias) + else param.op().arg + ), + value, + ) + for param, value in (params or {}).items() ] return self._execute(query, results=results, query_parameters=query_parameters) diff --git a/ibis_bigquery/udf/__init__.py b/ibis_bigquery/udf/__init__.py index 457e3ce..d528a0e 100644 --- a/ibis_bigquery/udf/__init__.py +++ b/ibis_bigquery/udf/__init__.py @@ -181,21 +181,18 @@ def wrapper(f): signature = inspect.signature(f) parameter_names = signature.parameters.keys() - udf_node_fields = collections.OrderedDict( - [ - (name, Arg(rlz.value(type))) - for name, type in zip(parameter_names, input_type) - ] - + [ - ( - "output_type", - lambda self, output_type=output_type: rlz.shape_like( - self.args, dtype=output_type - ), - ), - ("__slots__", ("js",)), - ] - ) + udf_node_fields = { + name: Arg(rlz.value(type)) + for name, type in zip(parameter_names, input_type) + } + + try: + udf_node_fields["output_type"] = rlz.shape_like("args", dtype=output_type) + except TypeError: + udf_node_fields["output_dtype"] = property(lambda _: output_type) + udf_node_fields["output_shape"] = rlz.shape_like("args") + + udf_node_fields["__slots__"] = ("js",) udf_node = create_udf_node(f.__name__, udf_node_fields) diff --git a/tests/system/test_client.py b/tests/system/test_client.py index bfb998b..602f26c 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -1,6 +1,7 @@ import collections import datetime import decimal +import re import ibis import ibis.expr.datatypes as dt @@ -11,6 +12,7 @@ import pandas.testing as tm import pytest import pytz +from pytest import param import ibis_bigquery from ibis_bigquery.client import bigquery_param @@ -19,6 +21,13 @@ IBIS_1_4_VERSION = packaging.version.Version("1.4.0") IBIS_3_0_VERSION = packaging.version.Version("3.0.0") +older_than_3 = pytest.mark.xfail( + IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3" +) +at_least_3 = pytest.mark.xfail( + IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3" +) + def test_table(alltypes): assert isinstance(alltypes, ir.TableExpr) @@ -204,7 +213,43 @@ def test_different_partition_col_name(monkeypatch, client): assert col in parted_alltypes.columns -def test_subquery_scalar_params(alltypes, project_id, dataset_id): +def scalar_params_ibis3(project_id, dataset_id): + return f"""\ +SELECT count\\(`foo`\\) AS `count` +FROM \\( + SELECT `string_col`, sum\\(`float_col`\\) AS `foo` + FROM \\( + SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` + FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` + \\) t1 + WHERE `timestamp_col` < @param_\\d+ + GROUP BY 1 +\\) t0""" + + +def scalar_params_not_ibis3(project_id, dataset_id): + return f"""\ +SELECT count\\(`foo`\\) AS `count` +FROM \\( + SELECT `string_col`, sum\\(`float_col`\\) AS `foo` + FROM \\( + SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` + FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` + WHERE `timestamp_col` < @my_param + \\) t1 + GROUP BY 1 +\\) t0""" + + +@pytest.mark.parametrize( + "expected_fn", + [ + param(scalar_params_ibis3, marks=[older_than_3], id="ibis3"), + param(scalar_params_not_ibis3, marks=[at_least_3], id="not_ibis3"), + ], +) +def test_subquery_scalar_params(alltypes, project_id, dataset_id, expected_fn): + expected = expected_fn(project_id, dataset_id) t = alltypes param = ibis.param("timestamp").name("my_param") expr = ( @@ -216,20 +261,7 @@ def test_subquery_scalar_params(alltypes, project_id, dataset_id): .foo.count() ) result = expr.compile(params={param: "20140101"}) - expected = """\ -SELECT count(`foo`) AS `count` -FROM ( - SELECT `string_col`, sum(`float_col`) AS `foo` - FROM ( - SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` - FROM `{}.{}.functional_alltypes` - WHERE `timestamp_col` < @my_param - ) t1 - GROUP BY 1 -) t0""".format( - project_id, dataset_id - ) - assert result == expected + assert re.match(expected, result) is not None def test_scalar_param_string(alltypes, df): @@ -457,18 +489,21 @@ def test_raw_sql(client): assert client.raw_sql("SELECT 1").fetchall() == [(1,)] -def test_scalar_param_scope(alltypes, project_id, dataset_id): +@pytest.mark.parametrize( + "pattern", + [ + param(r"@param_\d+", marks=[older_than_3], id="ibis3"), + param("@param", marks=[at_least_3], id="not_ibis3"), + ], +) +def test_scalar_param_scope(alltypes, project_id, dataset_id, pattern): t = alltypes param = ibis.param("timestamp") - mut = t.mutate(param=param).compile(params={param: "2017-01-01"}) - assert ( - mut - == """\ -SELECT *, @param AS `param` -FROM `{}.{}.functional_alltypes`""".format( - project_id, dataset_id - ) - ) + result = t.mutate(param=param).compile(params={param: "2017-01-01"}) + expected = f"""\ +SELECT \\*, {pattern} AS `param` +FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`""" + assert re.match(expected, result) is not None def test_parted_column_rename(parted_alltypes): diff --git a/tests/system/test_compiler.py b/tests/system/test_compiler.py index a261efb..629ab3b 100644 --- a/tests/system/test_compiler.py +++ b/tests/system/test_compiler.py @@ -1,24 +1,42 @@ +import re + import ibis import ibis.expr.datatypes as dt import packaging.version import pytest +from pytest import param pytestmark = pytest.mark.bigquery IBIS_VERSION = packaging.version.Version(ibis.__version__) IBIS_1_VERSION = packaging.version.Version("1.4.0") +IBIS_3_0_VERSION = packaging.version.Version("3.0.0") +older_than_3 = pytest.mark.xfail( + IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3" +) +at_least_3 = pytest.mark.xfail( + IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3" +) -def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id): + +@pytest.mark.parametrize( + "pattern", + [ + param(r"@param_\d+", marks=[older_than_3], id="ibis3"), + param("@param", marks=[at_least_3], id="not_ibis3"), + ], +) +def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id, pattern): date_string = "2009-03-01" param = ibis.param(dt.timestamp).name("param_0") expr = alltypes.mutate(param=param) params = {param: date_string} result = expr.compile(params=params) expected = f"""\ -SELECT *, @param AS `param` -FROM `{project_id}.{dataset_id}.functional_alltypes`""" - assert result == expected +SELECT \\*, {pattern} AS `param` +FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`""" + assert re.match(expected, result) is not None @pytest.mark.parametrize(