diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 44eb7323943b..c8d5b04e2895 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -15,6 +15,7 @@ import ibis.expr.operations as ops from ibis import NA from ibis.backends.datafusion import registry +from ibis.common.temporal import IntervalUnit from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowType @@ -927,20 +928,36 @@ def extract_quarter(op, **kw): @translate.register(ops.ExtractMinute) def extract_minute(op, **kw): arg = translate(op.arg, **kw) - return df.functions.date_part(df.literal("minute"), arg) + + if op.arg.dtype.is_time(): + return registry.UDFS["extract_minute_time"](arg) + elif op.arg.dtype.is_timestamp(): + return df.functions.date_part(df.literal("minute"), arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.ExtractHour) def extract_hour(op, **kw): arg = translate(op.arg, **kw) - return df.functions.date_part(df.literal("hour"), arg) + + if op.arg.dtype.is_time(): + return registry.UDFS["extract_hour_time"](arg) + elif op.arg.dtype.is_timestamp(): + return df.functions.date_part(df.literal("hour"), arg) + else: + raise com.OperationNotDefinedError( + f"The function is not defined for {type(op.arg)}" + ) @translate.register(ops.ExtractMillisecond) def extract_millisecond(op, **kw): arg = translate(op.arg, **kw) - if op.arg.dtype.is_date(): + if op.arg.dtype.is_time(): return registry.UDFS["extract_millisecond_time"](arg) elif op.arg.dtype.is_timestamp(): return registry.UDFS["extract_millisecond_timestamp"](arg) @@ -954,7 +971,7 @@ def extract_millisecond(op, **kw): def extract_second(op, **kw): arg = translate(op.arg, **kw) - if op.arg.dtype.is_date(): + if op.arg.dtype.is_time(): return registry.UDFS["extract_second_time"](arg) elif op.arg.dtype.is_timestamp(): return registry.UDFS["extract_second_timestamp"](arg) @@ -1028,3 +1045,19 @@ def extract_epoch_seconds(op, **kw): raise com.OperationNotDefinedError( f"The function is not defined for {type(op.arg)}" ) + + +@translate.register(ops.TimestampTruncate) +def timestamp_truncate(op, **kw): + arg = translate(op.arg, **kw) + unit = op.unit + if unit in ( + IntervalUnit.MILLISECOND, + IntervalUnit.MICROSECOND, + IntervalUnit.NANOSECOND, + ): + raise com.UnsupportedOperationError( + f"The function is not defined for time unit {unit}" + ) + + return df.functions.date_trunc(df.literal(unit.name.lower()), arg) diff --git a/ibis/backends/datafusion/registry.py b/ibis/backends/datafusion/registry.py index 7b152d827b5a..90b2fe2c740c 100644 --- a/ibis/backends/datafusion/registry.py +++ b/ibis/backends/datafusion/registry.py @@ -50,6 +50,14 @@ def extract_millisecond(array: pa.Array) -> pa.Array: return pc.cast(pc.millisecond(array), pa.int32()) +def extract_hour(array: pa.Array) -> pa.Array: + return pc.cast(pc.hour(array), pa.int32()) + + +def extract_minute(array: pa.Array) -> pa.Array: + return pc.cast(pc.minute(array), pa.int32()) + + UDFS = { "extract_microseconds_time": create_udf( ops.ExtractMicrosecond, @@ -111,4 +119,13 @@ def extract_millisecond(array: pa.Array) -> pa.Array: input_types=[dt.timestamp], name="extract_millisecond_timestamp", ), + "extract_hour_time": create_udf( + ops.ExtractHour, extract_hour, input_types=[dt.time], name="extract_hour_time" + ), + "extract_minute_time": create_udf( + ops.ExtractMinute, + extract_minute, + input_types=[dt.time], + name="extract_minute_time", + ), } diff --git a/ibis/backends/datafusion/tests/test_temporal.py b/ibis/backends/datafusion/tests/test_temporal.py new file mode 100644 index 000000000000..009b6041c022 --- /dev/null +++ b/ibis/backends/datafusion/tests/test_temporal.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from operator import methodcaller + +import pytest +from pytest import param + +import ibis + + +@pytest.mark.parametrize( + ("func", "expected"), + [ + param( + methodcaller("hour"), + 14, + id="hour", + ), + param( + methodcaller("minute"), + 48, + id="minute", + ), + param( + methodcaller("second"), + 5, + id="second", + ), + param( + methodcaller("millisecond"), + 359, + id="millisecond", + ), + ], +) +def test_time_extract_literal(con, func, expected): + value = ibis.time("14:48:05.359") + assert con.execute(func(value).name("tmp")) == expected diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 6951339cda3b..75db0e35c6b5 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -433,7 +433,14 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "ms", marks=[ pytest.mark.notimpl( - ["clickhouse", "impala", "mysql", "pyspark", "sqlite"], + [ + "clickhouse", + "impala", + "mysql", + "pyspark", + "sqlite", + "datafusion", + ], raises=com.UnsupportedOperationError, ), pytest.mark.broken( @@ -447,7 +454,15 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "us", marks=[ pytest.mark.notimpl( - ["clickhouse", "impala", "mysql", "pyspark", "sqlite", "trino"], + [ + "clickhouse", + "impala", + "mysql", + "pyspark", + "sqlite", + "trino", + "datafusion", + ], raises=com.UnsupportedOperationError, ), pytest.mark.broken( @@ -473,6 +488,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "snowflake", "trino", "mssql", + "datafusion", ], raises=com.UnsupportedOperationError, ), @@ -485,7 +501,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): ), ], ) -@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AttributeError,