diff --git a/ibis/backends/base/sqlglot/compiler.py b/ibis/backends/base/sqlglot/compiler.py index 16a8403e4d52..3338d13533dc 100644 --- a/ibis/backends/base/sqlglot/compiler.py +++ b/ibis/backends/base/sqlglot/compiler.py @@ -23,6 +23,7 @@ add_order_by_to_empty_ranking_window_functions, empty_in_values_right_side, one_to_zero_index, + rewrite_capitalize, sqlize, ) from ibis.expr.operations.udf import InputType @@ -179,6 +180,7 @@ class SQLGlotCompiler(abc.ABC): one_to_zero_index, add_one_to_nth_value_input, replace_bucket, + rewrite_capitalize, ) """A sequence of rewrites to apply to the expression tree before compilation.""" @@ -223,7 +225,6 @@ class SQLGlotCompiler(abc.ABC): ops.Asin: "asin", ops.Atan2: "atan2", ops.Atan: "atan", - ops.Capitalize: "initcap", ops.Cos: "cos", ops.Cot: "cot", ops.Count: "count", diff --git a/ibis/backends/base/sqlglot/rewrites.py b/ibis/backends/base/sqlglot/rewrites.py index 93cf4db49158..f3cdf7364231 100644 --- a/ibis/backends/base/sqlglot/rewrites.py +++ b/ibis/backends/base/sqlglot/rewrites.py @@ -245,6 +245,18 @@ def add_one_to_nth_value_input(_, **kwargs): return _.copy(nth=nth) +@replace(p.Capitalize) +def rewrite_capitalize(_, **kwargs): + """Rewrite Capitalize in terms of substring, concat, upper, and lower.""" + first = ops.Uppercase(ops.Substring(_.arg, start=0, length=1)) + # use length instead of length - 1 to avoid backends complaining about + # asking for negative length + # + # there are at most length - 1 characters, so asking for length is fine + rest = ops.Lowercase(ops.Substring(_.arg, start=1, length=ops.StringLength(_.arg))) + return ops.StringConcat((first, rest)) + + @replace(p.Sample) def rewrite_sample_as_filter(_, **kwargs): """Rewrite Sample as `t.filter(random() <= fraction)`. diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index accb2487b2e4..e4cfdfff298c 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -225,11 +225,6 @@ def visit_ArrayRepeat(self, op, *, arg, times): sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i)) ) - def visit_Capitalize(self, op, *, arg): - return self.f.concat( - self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) - ) - def visit_NthValue(self, op, *, arg, nth): if not isinstance(op.nth, ops.Literal): raise com.UnsupportedOperationError( diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py index dbb1a6e1066a..d94fe936763f 100644 --- a/ibis/backends/clickhouse/compiler.py +++ b/ibis/backends/clickhouse/compiler.py @@ -452,11 +452,6 @@ def visit_StringSplit(self, op, *, arg, delimiter): delimiter, self.cast(arg, dt.String(nullable=False)) ) - def visit_Capitalize(self, op, *, arg): - return self.f.concat( - self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) - ) - def visit_GroupConcat(self, op, *, arg, sep, where): call = self.agg.groupArray(arg, where=where) return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep)) diff --git a/ibis/backends/dask/tests/test_strings.py b/ibis/backends/dask/tests/test_strings.py index 240d78b2fd17..dfd70e291d3e 100644 --- a/ibis/backends/dask/tests/test_strings.py +++ b/ibis/backends/dask/tests/test_strings.py @@ -51,11 +51,6 @@ param(lambda s: s.reverse(), lambda s: s.str[::-1], id="reverse"), param(lambda s: s.lower(), lambda s: s.str.lower(), id="lower"), param(lambda s: s.upper(), lambda s: s.str.upper(), id="upper"), - param( - lambda s: s.capitalize(), - lambda s: s.str.capitalize(), - id="capitalize", - ), param(lambda s: s.repeat(2), lambda s: s * 2, id="repeat"), param( lambda s: s.contains("a"), diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index 2be8891c7a15..eb64b6d03fcb 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -9,7 +9,10 @@ from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler from ibis.backends.base.sqlglot.datatypes import DruidType from ibis.backends.base.sqlglot.dialects import Druid -from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter +from ibis.backends.base.sqlglot.rewrites import ( + rewrite_capitalize, + rewrite_sample_as_filter, +) class DruidCompiler(SQLGlotCompiler): @@ -17,7 +20,14 @@ class DruidCompiler(SQLGlotCompiler): dialect = Druid type_mapper = DruidType - rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites) + rewrites = ( + rewrite_sample_as_filter, + *( + rewrite + for rewrite in SQLGlotCompiler.rewrites + if rewrite is not rewrite_capitalize + ), + ) UNSUPPORTED_OPERATIONS = frozenset( ( diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index 17d599156d16..2f6bbb143823 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -317,11 +317,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): else: return None - def visit_Capitalize(self, op, *, arg): - return self.f.concat( - self.f.upper(self.f.substr(arg, 1, 1)), self.f.lower(self.f.substr(arg, 2)) - ) - def _neg_idx_to_pos(self, array, idx): arg_length = self.f.array_size(array) return self.if_( diff --git a/ibis/backends/impala/tests/snapshots/test_string_builtins/test_string_builtins/capitalize/out.sql b/ibis/backends/impala/tests/snapshots/test_string_builtins/test_string_builtins/capitalize/out.sql index d23a9119ada0..418a92882745 100644 --- a/ibis/backends/impala/tests/snapshots/test_string_builtins/test_string_builtins/capitalize/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_string_builtins/test_string_builtins/capitalize/out.sql @@ -1,3 +1,22 @@ SELECT - INITCAP(`t0`.`string_col`) AS `Capitalize(string_col)` + CONCAT( + UPPER( + IF( + ( + 0 + 1 + ) >= 1, + SUBSTRING(`t0`.`string_col`, 0 + 1, 1), + SUBSTRING(`t0`.`string_col`, 0 + 1 + LENGTH(`t0`.`string_col`), 1) + ) + ), + LOWER( + IF( + ( + 1 + 1 + ) >= 1, + SUBSTRING(`t0`.`string_col`, 1 + 1, LENGTH(`t0`.`string_col`)), + SUBSTRING(`t0`.`string_col`, 1 + 1 + LENGTH(`t0`.`string_col`), LENGTH(`t0`.`string_col`)) + ) + ) + ) AS `Capitalize(string_col)` FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/test_exprs.py b/ibis/backends/impala/tests/test_exprs.py index 34aa186bbda7..8fdc4e1b4358 100644 --- a/ibis/backends/impala/tests/test_exprs.py +++ b/ibis/backends/impala/tests/test_exprs.py @@ -286,7 +286,6 @@ def test_decimal_builtins_2(con, func, expected): (L(" a ").strip(), "a"), (L(" a ").lstrip(), "a "), (L(" a ").rstrip(), " a"), - (L("abcd").capitalize(), "Abcd"), (L("abcd").substr(0, 2), "ab"), (L("abcd").left(2), "ab"), (L("abcd").right(2), "cd"), diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index dd438ae5f867..3e696229a9bf 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -183,13 +183,6 @@ def visit_StringLength(self, op, *, arg): """ return paren(self.f.len(self.f.concat("A", arg, "Z")) - 2) - def visit_Capitalize(self, op, *, arg): - length = paren(self.f.len(self.f.concat("A", arg, "Z")) - 2) - return self.f.concat( - self.f.upper(self.f.substring(arg, 1, 1)), - self.f.lower(self.f.substring(arg, 2, length - 1)), - ) - def visit_GroupConcat(self, op, *, arg, sep, where): if where is not None: arg = self.if_(where, arg, NULL) diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index c0e720b3e33c..41460616badd 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -291,11 +291,6 @@ def visit_StringFind(self, op, *, arg, substr, start, end): return self.f.locate(substr, arg, start + 1) return self.f.locate(substr, arg) - def visit_Capitalize(self, op, *, arg): - return self.f.concat( - self.f.upper(self.f.left(arg, 1)), self.f.lower(self.f.substr(arg, 2)) - ) - def visit_LRStrip(self, op, *, arg, position): return reduce( lambda arg, char: self.f.trim(this=arg, position=position, expression=char), diff --git a/ibis/backends/pandas/tests/test_strings.py b/ibis/backends/pandas/tests/test_strings.py index e583cb53437e..3aa9b4bc5a00 100644 --- a/ibis/backends/pandas/tests/test_strings.py +++ b/ibis/backends/pandas/tests/test_strings.py @@ -49,11 +49,6 @@ param(lambda s: s.reverse(), lambda s: s.str[::-1], id="reverse"), param(lambda s: s.lower(), lambda s: s.str.lower(), id="lower"), param(lambda s: s.upper(), lambda s: s.str.upper(), id="upper"), - param( - lambda s: s.capitalize(), - lambda s: s.str.capitalize(), - id="capitalize", - ), param(lambda s: s.repeat(2), lambda s: s * 2, id="repeat"), param( lambda s: s.contains("a"), diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index f51ca271dd55..22a88d84cb58 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -503,7 +503,6 @@ def in_values(op, **kw): ops.RStrip: "strip_chars_end", ops.Lowercase: "to_lowercase", ops.Uppercase: "to_uppercase", - ops.Capitalize: "to_titlecase", } @@ -514,6 +513,15 @@ def string_length(op, **kw): return arg.str.len_bytes().cast(typ) +@translate.register(ops.Capitalize) +def capitalize(op, **kw): + arg = translate(op.arg, **kw) + typ = dtype_to_polars(op.dtype) + first = arg.str.slice(0, 1).str.to_uppercase() + rest = arg.str.slice(1, None).str.to_lowercase() + return (first + rest).cast(typ) + + @translate.register(ops.StringUnary) def string_unary(op, **kw): arg = translate(op.arg, **kw) diff --git a/ibis/backends/postgres/tests/test_functions.py b/ibis/backends/postgres/tests/test_functions.py index 64e3f3cfc25a..de1256853df3 100644 --- a/ibis/backends/postgres/tests/test_functions.py +++ b/ibis/backends/postgres/tests/test_functions.py @@ -243,14 +243,6 @@ def test_string_contains(con, haystack, needle, expected): assert con.execute(expr) == expected -@pytest.mark.parametrize( - ("value", "expected"), - [("foo bar foo", "Foo Bar Foo"), ("foobar Foo", "Foobar Foo")], -) -def test_capitalize(con, value, expected): - assert con.execute(L(value).capitalize()) == expected - - def test_repeat(con): expr = L("bar ").repeat(3) assert con.execute(expr) == "bar bar bar " diff --git a/ibis/backends/risingwave/tests/test_functions.py b/ibis/backends/risingwave/tests/test_functions.py index d680fb3190f9..04281a722d9c 100644 --- a/ibis/backends/risingwave/tests/test_functions.py +++ b/ibis/backends/risingwave/tests/test_functions.py @@ -90,14 +90,6 @@ def test_string_contains(con, haystack, needle, expected): assert con.execute(expr) == expected -@pytest.mark.parametrize( - ("value", "expected"), - [("foo bar foo", "Foo Bar Foo"), ("foobar Foo", "Foobar Foo")], -) -def test_capitalize(con, value, expected): - assert con.execute(L(value).capitalize()) == expected - - def test_repeat(con): expr = L("bar ").repeat(3) assert con.execute(expr) == "bar bar bar " diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py index 380be055095e..d4d646dbb95f 100644 --- a/ibis/backends/sqlite/udf.py +++ b/ibis/backends/sqlite/udf.py @@ -235,11 +235,6 @@ def _ibis_string_ascii(string): return ord(string[0]) -@udf -def _ibis_capitalize(string): - return string.capitalize() - - @udf def _ibis_rpad(string, width, pad): return string.ljust(width, pad)[:width] diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 30151b01b668..2595492e0803 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -562,11 +562,6 @@ def uses_java_re(t): lambda t: t.string_col.str.rstrip(), id="rstrip", ), - param( - lambda t: t.string_col.capitalize(), - lambda t: t.string_col.str.capitalize(), - id="capitalize", - ), param( lambda t: t.date_string_col.substr(2, 3), lambda t: t.date_string_col.str[2:5], @@ -844,11 +839,52 @@ def test_parse_url(con, result_func, expected): assert result == expected -def test_capitalize(con): - s = ibis.literal("aBc") - expected = "Abc" +@pytest.mark.parametrize( + ("inp, expected"), + [ + param( + None, + None, + id="none", + marks=[ + pytest.mark.notyet( + ["druid"], + raises=PyDruidProgrammingError, + reason="illegal use of NULL", + ) + ], + ), + param( + "", + "", + id="empty", + marks=[ + pytest.mark.notyet( + ["oracle"], + reason="https://github.com/oracle/python-oracledb/issues/298", + raises=AssertionError, + ), + pytest.mark.notyet(["exasol"], raises=AssertionError), + ], + ), + param("Abc", "Abc", id="no_change"), + param("abc", "Abc", id="lower_to_upper"), + param("aBC", "Abc", id="mixed_to_upper"), + param(" abc", " abc", id="leading_space"), + param("9abc", "9abc", id="leading_digit"), + param("aBc dEf", "Abc def", id="mixed_with_space"), + param("aBc-dEf", "Abc-def", id="mixed_with_hyphen"), + param("aBc1dEf", "Abc1def", id="mixed_with_digit"), + ], +) +def test_capitalize(con, inp, expected): + s = ibis.literal(inp, type="string") expr = s.capitalize() - assert con.execute(expr) == expected + result = con.execute(expr) + if expected is not None: + assert result == expected + else: + assert pd.isnull(result) @pytest.mark.notimpl( diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index 10d3d6331e77..8a3a6a4c5903 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -388,7 +388,10 @@ def rstrip(self) -> StringValue: return ops.RStrip(self).to_expr() def capitalize(self) -> StringValue: - """Capitalize the input string. + """Uppercase the first letter, lowercase the rest. + + This API matches the semantics of the Python [](`str.capitalize`) + method. Returns ------- @@ -399,7 +402,7 @@ def capitalize(self) -> StringValue: -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"s": ["abc", "def", "ghi"]}) + >>> t = ibis.memtable({"s": ["aBC", " abc", "ab cd", None]}) >>> t.s.capitalize() ┏━━━━━━━━━━━━━━━┓ ┃ Capitalize(s) ┃ @@ -407,14 +410,22 @@ def capitalize(self) -> StringValue: │ string │ ├───────────────┤ │ Abc │ - │ Def │ - │ Ghi │ + │ abc │ + │ Ab cd │ + │ NULL │ └───────────────┘ """ return ops.Capitalize(self).to_expr() initcap = capitalize + @util.deprecated( + instead="use the `capitalize` method", as_of="9.0", removed_in="10.0" + ) + def initcap(self) -> StringValue: + """Deprecated. Use `capitalize` instead.""" + return self.capitalize() + def __contains__(self, *_: Any) -> bool: raise TypeError("Use string_expr.contains(arg)")