From 8d35c7ea603c4f910bb0653927203476ad00141c Mon Sep 17 00:00:00 2001 From: ilongin Date: Tue, 7 Jan 2025 14:42:48 +0100 Subject: [PATCH 01/11] added isnone function and tests --- src/datachain/func/conditional.py | 26 ++++++++++++++++++++++++++ src/datachain/lib/dc.py | 7 ++++++- tests/unit/sql/test_conditional.py | 15 +++++++++++++++ tests/unit/test_func.py | 18 ++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index 396d5b512..1c089b48e 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -4,6 +4,7 @@ from sqlalchemy.sql.elements import BinaryExpression from datachain.lib.utils import DataChainParamsError +from datachain.query.schema import Column from datachain.sql.functions import conditional from .func import ColT, Func @@ -131,3 +132,28 @@ def case( kwargs = {"else_": else_} return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_) + + +def isnone(col: Union[str, Column]) -> Func: + """ + Returns True if column value or literal is None, otherwise False + Args: + col (str | Column | literal): Column or literal to check if None. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + + Returns: + Func: A Func object that represents the conditional to check if column is None. + + Example: + ```py + dc.mutate(test=isnone("value")) + ``` + """ + from datachain import C + + if isinstance(col, str): + # if string, it is assumed to be the name of the column + col = C(col) + + return case((col == None, True), else_=False) # type: ignore [arg-type] # noqa: E711 diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index a9542d9d8..7db3e6e4b 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1141,7 +1141,12 @@ def mutate(self, **kwargs) -> "Self": mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr] elif isinstance(value, Func): # adding new signal - mutated[name] = value.get_column(schema) + # mutated[name] = value.get_column(schema) + v = value.get_column(schema) + print("in mutate") + print(v) + print(type(v)) + mutated[name] = v else: # adding new signal mutated[name] = value diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py index 27466c91a..1a052cb79 100644 --- a/tests/unit/sql/test_conditional.py +++ b/tests/unit/sql/test_conditional.py @@ -107,3 +107,18 @@ def test_case_wrong_result_type(warehouse): "Case supports only python literals ([, , " ", , ]) for values" ) + + +@pytest.mark.parametrize( + "val,expected", + [ + [None, True], + [func.literal("abcd"), False], + ], +) +def test_isnone(warehouse, val, expected): + from datachain.func.conditional import isnone + + query = select(isnone(val)) + result = tuple(warehouse.db.execute(query)) + assert result == ((expected,),) diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index b8eb54a1d..2f9df5bb3 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -17,6 +17,7 @@ sqlite_byte_hamming_distance, sqlite_int_hash_64, ) +from tests.utils import skip_if_not_sqlite @pytest.fixture() @@ -660,3 +661,20 @@ def test_case_mutate(dc, val, else_, type_): [val, else_, else_, else_, else_] ) assert res.schema["test"] == type_ + + +@pytest.mark.parametrize("col", ["val", C("val")]) +@skip_if_not_sqlite +def test_isnone_mutate(col): + from datachain.func.conditional import isnone + + dc = DataChain.from_values( + num=list(range(1, 6)), + val=[None if i > 3 else "A" for i in range(1, 6)], + ) + + res = dc.mutate(test=isnone(col)) + assert list(res.order_by("test").collect("test")) == sorted( + [False, False, False, True, True] + ) + assert res.schema["test"] is bool From fddab8a2cd1f10276de1c418957d120157485f78 Mon Sep 17 00:00:00 2001 From: ilongin Date: Tue, 7 Jan 2025 15:32:25 +0100 Subject: [PATCH 02/11] removing prints --- src/datachain/lib/dc.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 7db3e6e4b..a9542d9d8 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1141,12 +1141,7 @@ def mutate(self, **kwargs) -> "Self": mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr] elif isinstance(value, Func): # adding new signal - # mutated[name] = value.get_column(schema) - v = value.get_column(schema) - print("in mutate") - print(v) - print(type(v)) - mutated[name] = v + mutated[name] = value.get_column(schema) else: # adding new signal mutated[name] = value From ba5d315d834a8d32d17fa209ce8a32433a1b731f Mon Sep 17 00:00:00 2001 From: ilongin Date: Thu, 9 Jan 2025 15:52:31 +0100 Subject: [PATCH 03/11] added nested case ability --- src/datachain/func/conditional.py | 24 +++++++---- src/datachain/func/func.py | 11 +++-- tests/unit/sql/test_conditional.py | 4 +- tests/unit/test_func.py | 66 ++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 12 deletions(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index eae46cbaa..27213a3ef 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -9,7 +9,7 @@ from .func import ColT, Func -CaseT = Union[int, float, complex, bool, str] +CaseT = Union[int, float, complex, bool, str, Func] def greatest(*args: Union[ColT, float]) -> Func: @@ -88,7 +88,7 @@ def least(*args: Union[ColT, float]) -> Func: ) -def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: +def case(*args: tuple, else_=None) -> Func: """ Returns the case function that produces case expression which has a list of conditions and corresponding results. Results can only be python primitives @@ -112,15 +112,24 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: """ supported_types = [int, float, complex, str, bool] - type_ = type(else_) if else_ else None + def _get_type(val): + if isinstance(val, Func): + # nested functions + return val.result_type + return type(val) if not args: raise DataChainParamsError("Missing statements") + type_ = _get_type(else_) if else_ is not None else None + for arg in args: - if type_ and not isinstance(arg[1], type_): - raise DataChainParamsError("Statement values must be of the same type") - type_ = type(arg[1]) + arg_type = _get_type(arg[1]) + if type_ and arg_type != type_: + raise DataChainParamsError( + f"Statement values must be of the same type, got {type_} amd {arg_type}" + ) + type_ = arg_type if type_ not in supported_types: raise DataChainParamsError( @@ -128,7 +137,8 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: ) kwargs = {"else_": else_} - return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_) + + return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_) def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func: diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index 90ee5796e..97b2b5acd 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -23,7 +23,7 @@ from .window import Window -ColT = Union[str, ColumnElement, "Func"] +ColT = Union[str, ColumnElement, "Func", tuple] class Func(Function): @@ -78,7 +78,7 @@ def _db_cols(self) -> Sequence[ColT]: return ( [ col - if isinstance(col, (Func, BindParameter, Case, Comparator)) + if isinstance(col, (Func, BindParameter, Case, Comparator, tuple)) else ColumnMeta.to_db_name( col.name if isinstance(col, ColumnElement) else col ) @@ -382,6 +382,8 @@ def get_column( sql_type = python_to_sql(col_type) def get_col(col: ColT) -> ColT: + if isinstance(col, tuple): + return tuple(get_col(x) for x in col) if isinstance(col, Func): return col.get_column(signals_schema, table=table) if isinstance(col, str): @@ -391,7 +393,8 @@ def get_col(col: ColT) -> ColT: return col cols = [get_col(col) for col in self._db_cols] - func_col = self.inner(*cols, *self.args, **self.kwargs) + kwargs = {k: get_col(v) for k, v in self.kwargs.items()} + func_col = self.inner(*cols, *self.args, **kwargs) if self.is_window: if not self.window: @@ -423,7 +426,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": return sql_to_python(col) return signals_schema.get_column_type( - col.name if isinstance(col, ColumnElement) else col + col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type] ) diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py index a37aaedcd..16f88bc57 100644 --- a/tests/unit/sql/test_conditional.py +++ b/tests/unit/sql/test_conditional.py @@ -96,7 +96,9 @@ def test_case_not_same_result_types(warehouse): val = 2 with pytest.raises(DataChainParamsError) as exc_info: select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D")) - assert str(exc_info.value) == "Statement values must be of the same type" + assert str(exc_info.value) == ( + "Statement values must be of the same type, got amd " + ) def test_case_wrong_result_type(warehouse): diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index 4974fce32..4b4237f12 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -665,6 +665,59 @@ def test_case_mutate(dc, val, else_, type_): assert res.schema["test"] == type_ +@pytest.mark.parametrize( + "val,else_,type_", + [ + ["A", "D", str], + [1, 2, int], + [1.5, 2.5, float], + [True, False, bool], + ], +) +def test_nested_case_on_condition_mutate(dc, val, else_, type_): + res = dc.mutate( + test=case((case((C("num") < 2, True), else_=False), val), else_=else_) + ) + assert list(res.order_by("test").collect("test")) == sorted( + [val, else_, else_, else_, else_] + ) + assert res.schema["test"] == type_ + + +@pytest.mark.parametrize( + "v1,v2,v3,type_", + [ + ["A", "B", "C", str], + [1, 2, 3, int], + [1.5, 2.5, 3.5, float], + [False, True, True, bool], + ], +) +def test_nested_case_on_value_mutate(dc, v1, v2, v3, type_): + res = dc.mutate( + test=case((C("num") < 4, case((C("num") < 2, v1), else_=v2)), else_=v3) + ) + assert list(res.order_by("num").collect("test")) == sorted([v1, v2, v2, v3, v3]) + assert res.schema["test"] == type_ + + +@pytest.mark.parametrize( + "v1,v2,v3,type_", + [ + ["A", "B", "C", str], + [1, 2, 3, int], + [1.5, 2.5, 3.5, float], + [False, True, True, bool], + ], +) +def test_nested_case_on_else_mutate(dc, v1, v2, v3, type_): + res = dc.mutate( + test=case((C("num") < 3, v1), else_=case((C("num") < 4, v2), else_=v3)) + ) + assert list(res.order_by("num").collect("test")) == sorted([v1, v1, v2, v3, v3]) + assert res.schema["test"] == type_ + + @pytest.mark.parametrize( "if_val,else_val,type_", [ @@ -695,3 +748,16 @@ def test_isnone_mutate(col): [False, False, False, True, True] ) assert res.schema["test"] is bool + + +@pytest.mark.parametrize("col", [C("val"), "val"]) +@skip_if_not_sqlite +def test_isnone_with_ifelse_mutate(col): + dc = DataChain.from_values( + num=list(range(1, 6)), + val=[None if i > 3 else "A" for i in range(1, 6)], + ) + + res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE")) + assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2 + assert res.schema["test"] is str From 40ea99f11920f47b77e8105d20fa6cd18eb70c70 Mon Sep 17 00:00:00 2001 From: ilongin Date: Fri, 10 Jan 2025 10:10:30 +0100 Subject: [PATCH 04/11] fixing typing of conditional functions --- src/datachain/func/conditional.py | 35 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index 27213a3ef..fe109291a 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -1,7 +1,7 @@ -from typing import Union +from typing import Optional, Union +from sqlalchemy import ColumnElement from sqlalchemy import case as sql_case -from sqlalchemy.sql.elements import BinaryExpression from datachain.lib.utils import DataChainParamsError from datachain.query.schema import Column @@ -88,17 +88,18 @@ def least(*args: Union[ColT, float]) -> Func: ) -def case(*args: tuple, else_=None) -> Func: +def case( + *args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None +) -> Func: """ Returns the case function that produces case expression which has a list of conditions and corresponding results. Results can only be python primitives like string, numbes or booleans. Result type is inferred from condition results. Args: - args (tuple(BinaryExpression, value(str | int | float | complex | bool): - - Tuple of binary expression and values pair which corresponds to one - case condition - value - else_ (str | int | float | complex | bool): else value in case expression + args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))): + - Tuple of condition and values pair + else_ (str | int | float | complex | bool, Func): else value in case expression Returns: Func: A Func object that represents the case function. @@ -141,17 +142,21 @@ def _get_type(val): return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_) -def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func: +def ifelse( + condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT +) -> Func: """ Returns the ifelse function that produces if expression which has a condition - and values for true and false outcome. Results can only be python primitives - like string, numbes or booleans. Result type is inferred from the values. + and values for true and false outcome. Results can be one of python primitives + like string, numbes or booleans, but can also be nested functions. + Result type is inferred from the values. Args: - condition: BinaryExpression - condition which is evaluated - if_val: (str | int | float | complex | bool): value for true condition outcome - else_val: (str | int | float | complex | bool): value for false condition - outcome + condition: (ColumnElement, Func) - condition which is evaluated + if_val: (str | int | float | complex | bool, Func): value for true + condition outcome + else_val: (str | int | float | complex | bool, Func): value for false condition + outcome Returns: Func: A Func object that represents the ifelse function. @@ -159,7 +164,7 @@ def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func: Example: ```py dc.mutate( - res=func.ifelse(C("num") > 0, "P", "N"), + res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"), ) ``` """ From a6964350d4b7f7a75d0ebeb9fc552cec47a94f0c Mon Sep 17 00:00:00 2001 From: ilongin Date: Fri, 10 Jan 2025 11:06:35 +0100 Subject: [PATCH 05/11] added fix for columns - literals inside case --- src/datachain/func/func.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index 97b2b5acd..d3e364cd1 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -381,19 +381,23 @@ def get_column( col_type = self.get_result_type(signals_schema) sql_type = python_to_sql(col_type) - def get_col(col: ColT) -> ColT: + def get_col(col: ColT, string_as_literal=False) -> ColT: + # string_as_literal is used only for conditionals like `case()` where + # literals are nested inside ColT as we have tuples of condition - values + # and if user wants to set some case value as column, explicit `C("col")` + # syntax must be used to distinguish from literals if isinstance(col, tuple): - return tuple(get_col(x) for x in col) + return tuple(get_col(x, string_as_literal=True) for x in col) if isinstance(col, Func): return col.get_column(signals_schema, table=table) - if isinstance(col, str): + if isinstance(col, str) and not string_as_literal: column = Column(col, sql_type) column.table = table return column return col cols = [get_col(col) for col in self._db_cols] - kwargs = {k: get_col(v) for k, v in self.kwargs.items()} + kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()} func_col = self.inner(*cols, *self.args, **kwargs) if self.is_window: From 9038cc036cf24df9c4316d595397805fb863c8b3 Mon Sep 17 00:00:00 2001 From: ilongin Date: Fri, 10 Jan 2025 12:14:15 +0100 Subject: [PATCH 06/11] changing docs --- src/datachain/func/conditional.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index fe109291a..b4deab564 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -173,18 +173,17 @@ def ifelse( def isnone(col: Union[str, Column]) -> Func: """ - Returns True if column value or literal is None, otherwise False + Returns True if column value is None, otherwise False Args: - col (str | Column | literal): Column or literal to check if None. + col (str | Column): Column to check if it's None or not. If a string is provided, it is assumed to be the name of the column. - If a literal is provided, it is assumed to be a string literal. Returns: Func: A Func object that represents the conditional to check if column is None. Example: ```py - dc.mutate(test=isnone("value")) + dc.mutate(test=ifelse(isnone("col"), "NONE", "NOT_NONE")) ``` """ from datachain import C From 1f9d382fabd7e9e29cd2f915014d9a4704424492 Mon Sep 17 00:00:00 2001 From: ivan Date: Sat, 11 Jan 2025 00:54:56 +0100 Subject: [PATCH 07/11] fixing docs --- src/datachain/func/conditional.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index b4deab564..ecb2d6265 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -94,12 +94,13 @@ def case( """ Returns the case function that produces case expression which has a list of conditions and corresponding results. Results can only be python primitives - like string, numbes or booleans. Result type is inferred from condition results. + like string, numbers or booleans. Result type is inferred from condition results. Args: args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))): - - Tuple of condition and values pair - else_ (str | int | float | complex | bool, Func): else value in case expression + Tuple of condition and values pair. + else_ (str | int | float | complex | bool, Func): else value in case + expression. Returns: Func: A Func object that represents the case function. @@ -128,7 +129,7 @@ def _get_type(val): arg_type = _get_type(arg[1]) if type_ and arg_type != type_: raise DataChainParamsError( - f"Statement values must be of the same type, got {type_} amd {arg_type}" + f"Statement values must be of the same type, got {type_} and {arg_type}" ) type_ = arg_type @@ -152,11 +153,11 @@ def ifelse( Result type is inferred from the values. Args: - condition: (ColumnElement, Func) - condition which is evaluated - if_val: (str | int | float | complex | bool, Func): value for true - condition outcome - else_val: (str | int | float | complex | bool, Func): value for false condition - outcome + condition (ColumnElement, Func): Condition which is evaluated. + if_val (str | int | float | complex | bool, Func): Value for true + condition outcome. + else_val (str | int | float | complex | bool, Func): Value for false condition + outcome. Returns: Func: A Func object that represents the ifelse function. @@ -174,6 +175,7 @@ def ifelse( def isnone(col: Union[str, Column]) -> Func: """ Returns True if column value is None, otherwise False + Args: col (str | Column): Column to check if it's None or not. If a string is provided, it is assumed to be the name of the column. @@ -183,7 +185,7 @@ def isnone(col: Union[str, Column]) -> Func: Example: ```py - dc.mutate(test=ifelse(isnone("col"), "NONE", "NOT_NONE")) + dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY")) ``` """ from datachain import C @@ -192,4 +194,4 @@ def isnone(col: Union[str, Column]) -> Func: # if string, it is assumed to be the name of the column col = C(col) - return case((col == None, True), else_=False) # type: ignore [arg-type] # noqa: E711 + return case((col == None, True), else_=False) # noqa: E711 From c25e35958f3f3765c6262db566fedb1d41fe77a8 Mon Sep 17 00:00:00 2001 From: ivan Date: Sat, 11 Jan 2025 01:17:30 +0100 Subject: [PATCH 08/11] fixing docs --- tests/unit/test_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index 4b4237f12..4bed7ddec 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -736,7 +736,6 @@ def test_ifelse_mutate(dc, if_val, else_val, type_): @pytest.mark.parametrize("col", ["val", C("val")]) -@skip_if_not_sqlite def test_isnone_mutate(col): dc = DataChain.from_values( num=list(range(1, 6)), From a3ca3db50acca560efdc022960a69ee64a3eeb00 Mon Sep 17 00:00:00 2001 From: ilongin Date: Thu, 16 Jan 2025 14:39:57 +0100 Subject: [PATCH 09/11] fixing docs and other parts --- src/datachain/func/conditional.py | 18 ++++++++++-------- tests/unit/sql/test_conditional.py | 15 ++++++++++++++- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index ecb2d6265..6f7457c54 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -93,14 +93,16 @@ def case( ) -> Func: """ Returns the case function that produces case expression which has a list of - conditions and corresponding results. Results can only be python primitives - like string, numbers or booleans. Result type is inferred from condition results. + conditions and corresponding results. Results can be python primitives like string, + numbers or booleans but can also be other nested function (including case function). + Result type is inferred from condition results. Args: args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))): Tuple of condition and values pair. - else_ (str | int | float | complex | bool, Func): else value in case - expression. + else_ (str | int | float | complex | bool, Func): optional else value in case + expression. If omitted, and no case conditions are satisfied, the result + will be None (NULL in DB). Returns: Func: A Func object that represents the case function. @@ -149,7 +151,7 @@ def ifelse( """ Returns the ifelse function that produces if expression which has a condition and values for true and false outcome. Results can be one of python primitives - like string, numbes or booleans, but can also be nested functions. + like string, numbers or booleans, but can also be nested functions. Result type is inferred from the values. Args: @@ -165,7 +167,7 @@ def ifelse( Example: ```py dc.mutate( - res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"), + res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY") ) ``` """ @@ -174,7 +176,7 @@ def ifelse( def isnone(col: Union[str, Column]) -> Func: """ - Returns True if column value is None, otherwise False + Returns True if column value is None, otherwise False. Args: col (str | Column): Column to check if it's None or not. @@ -194,4 +196,4 @@ def isnone(col: Union[str, Column]) -> Func: # if string, it is assumed to be the name of the column col = C(col) - return case((col == None, True), else_=False) # noqa: E711 + return case((col.is_(None) if col is not None else True, True), else_=False) diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py index 16f88bc57..e78cdcb24 100644 --- a/tests/unit/sql/test_conditional.py +++ b/tests/unit/sql/test_conditional.py @@ -86,6 +86,19 @@ def test_case(warehouse, val, expected): assert result == ((expected,),) +@pytest.mark.parametrize( + "val,expected", + [ + (1, "A"), + (2, None), + ], +) +def test_case_without_else(warehouse, val, expected): + query = select(func.case(*[(val < 2, "A")])) + result = tuple(warehouse.db.execute(query)) + assert result == ((expected,),) + + def test_case_missing_statements(warehouse): with pytest.raises(DataChainParamsError) as exc_info: select(func.case(*[], else_="D")) @@ -97,7 +110,7 @@ def test_case_not_same_result_types(warehouse): with pytest.raises(DataChainParamsError) as exc_info: select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D")) assert str(exc_info.value) == ( - "Statement values must be of the same type, got amd " + "Statement values must be of the same type, got and " ) From fbbd8431300c167360f45c23887fab61e7ed94fa Mon Sep 17 00:00:00 2001 From: ilongin Date: Thu, 16 Jan 2025 15:53:06 +0100 Subject: [PATCH 10/11] added check for tuple --- src/datachain/func/func.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index d3e364cd1..072519df2 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -423,6 +423,11 @@ def get_col(col: ColT, string_as_literal=False) -> ColT: def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": + if isinstance(col, tuple): + raise DataChainParamsError( + "Cannot get type from tuple, please provide type hint to the function" + ) + if isinstance(col, Func): return col.get_result_type(signals_schema) @@ -430,7 +435,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": return sql_to_python(col) return signals_schema.get_column_type( - col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type] + col.name if isinstance(col, ColumnElement) else col ) From 33faee1d5c011fabb573116ca0b3567a1bfd5112 Mon Sep 17 00:00:00 2001 From: ilongin Date: Fri, 17 Jan 2025 11:32:15 +0100 Subject: [PATCH 11/11] skipping isnone for clickhouse --- tests/unit/test_func.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index 4bed7ddec..4b4237f12 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -736,6 +736,7 @@ def test_ifelse_mutate(dc, if_val, else_val, type_): @pytest.mark.parametrize("col", ["val", C("val")]) +@skip_if_not_sqlite def test_isnone_mutate(col): dc = DataChain.from_values( num=list(range(1, 6)),