diff --git a/docs/source/user-guide/common-operations/functions.rst b/docs/source/user-guide/common-operations/functions.rst index ad71c72a..12097be8 100644 --- a/docs/source/user-guide/common-operations/functions.rst +++ b/docs/source/user-guide/common-operations/functions.rst @@ -38,7 +38,7 @@ DataFusion offers mathematical functions such as :py:func:`~datafusion.functions .. ipython:: python - from datafusion import col, literal + from datafusion import col, literal, string_literal, str_lit from datafusion import functions as f df.select( @@ -104,6 +104,17 @@ This also includes the functions for regular expressions like :py:func:`~datafus f.regexp_replace(col('"Name"'), literal("saur"), literal("fleur")).alias("flowers") ) +Casting +------- + +Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast` + +.. ipython:: python + + df.select( + f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"), + f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int") + ) Other ----- diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index e0bc57f4..7367b0d3 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -107,6 +107,19 @@ def literal(value): return Expr.literal(value) +def string_literal(value): + """Create a UTF8 literal expression. + + It differs from `literal` which creates a UTF8view literal. + """ + return Expr.string_literal(value) + + +def str_lit(value): + """Alias for `string_literal`.""" + return string_literal(value) + + def lit(value): """Create a literal expression.""" return Expr.literal(value) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index b1072438..16add16f 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -380,6 +380,22 @@ def literal(value: Any) -> Expr: value = pa.scalar(value) return Expr(expr_internal.Expr.literal(value)) + @staticmethod + def string_literal(value: str) -> Expr: + """Creates a new expression representing a UTF8 literal value. + + It is different from `literal` because it is pa.string() instead of + pa.string_view() + + This is needed for cases where DataFusion is expecting a UTF8 instead of + UTF8View literal, like in: + https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 + """ + if isinstance(value, str): + value = pa.scalar(value, type=pa.string()) + return Expr(expr_internal.Expr.literal(value)) + return Expr.literal(value) + @staticmethod def column(value: str) -> Expr: """Creates a new expression representing a column.""" diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 15ad8822..2e74073b 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -82,6 +82,7 @@ "array_to_string", "array_union", "arrow_typeof", + "arrow_cast", "ascii", "asin", "asinh", @@ -1108,6 +1109,11 @@ def arrow_typeof(arg: Expr) -> Expr: return Expr(f.arrow_typeof(arg.expr)) +def arrow_cast(expr: Expr, data_type: Expr) -> Expr: + """Casts an expression to a specified data type.""" + return Expr(f.arrow_cast(expr.expr, data_type.expr)) + + def random() -> Expr: """Returns a random value in the range ``0.0 <= x < 1.0``.""" return Expr(f.random()) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 0d2fa8f9..5dce188e 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -23,7 +23,7 @@ from datafusion import SessionContext, column from datafusion import functions as f -from datafusion import literal +from datafusion import literal, string_literal np.seterr(invalid="ignore") @@ -907,6 +907,22 @@ def test_temporal_functions(df): assert result.column(10) == pa.array([31, 26, 2], type=pa.float64()) +def test_arrow_cast(df): + df = df.select( + # we use `string_literal` to return utf8 instead of `literal` which returns + # utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view + # https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179 + f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"), + f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"), + ) + result = df.collect() + assert len(result) == 1 + result = result[0] + + assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64()) + assert result.column(1) == pa.array([4, 5, 6], type=pa.int32()) + + def test_case(df): df = df.select( f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)), diff --git a/src/functions.rs b/src/functions.rs index e29c57f9..2f8a96d9 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -389,7 +389,6 @@ macro_rules! expr_fn { } }; } - /// Generates a [pyo3] wrapper for [datafusion::functions::expr_fn] /// /// These functions take a single `Vec` argument using `pyo3(signature = (*args))`. @@ -564,6 +563,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword expr_fn_vec!(named_struct); expr_fn!(from_unixtime, unixtime); expr_fn!(arrow_typeof, arg_1); +expr_fn!(arrow_cast, arg_1 datatype); expr_fn!(random); // Array Functions @@ -856,6 +856,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(range))?; m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; + m.add_wrapped(wrap_pyfunction!(arrow_cast))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; m.add_wrapped(wrap_pyfunction!(asinh))?;