Skip to content

Commit

Permalink
fix(strings): make StringValue.capitalize() consistent across backe…
Browse files Browse the repository at this point in the history
…nds (#8270)

Fixes #8271.

BREAKING CHANGE: Backends that previously used initcap (analogous to str.title) to implement StringValue.capitalize() will produce different results when the input string contains multiple words (a word's definition being backend-specific).
  • Loading branch information
NickCrews authored Feb 15, 2024
1 parent e2f950d commit c4055d6
Show file tree
Hide file tree
Showing 18 changed files with 115 additions and 77 deletions.
3 changes: 2 additions & 1 deletion ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
add_order_by_to_empty_ranking_window_functions,
empty_in_values_right_side,
one_to_zero_index,
rewrite_capitalize,
sqlize,
)
from ibis.expr.operations.udf import InputType
Expand Down Expand Up @@ -179,6 +180,7 @@ class SQLGlotCompiler(abc.ABC):
one_to_zero_index,
add_one_to_nth_value_input,
replace_bucket,
rewrite_capitalize,
)
"""A sequence of rewrites to apply to the expression tree before compilation."""

Expand Down Expand Up @@ -223,7 +225,6 @@ class SQLGlotCompiler(abc.ABC):
ops.Asin: "asin",
ops.Atan2: "atan2",
ops.Atan: "atan",
ops.Capitalize: "initcap",
ops.Cos: "cos",
ops.Cot: "cot",
ops.Count: "count",
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ def add_one_to_nth_value_input(_, **kwargs):
return _.copy(nth=nth)


@replace(p.Capitalize)
def rewrite_capitalize(_, **kwargs):
    """Express Capitalize as upper(first character) + lower(remaining characters).

    The matched node is rebuilt from Substring, Uppercase, Lowercase and
    StringConcat operations applied to the same argument.
    """
    arg = _.arg
    head = ops.Uppercase(ops.Substring(arg, start=0, length=1))
    # Ask for `length` characters for the tail instead of `length - 1`, so that
    # backends never complain about a negative length. At most `length - 1`
    # characters follow the first one, so over-asking is fine.
    tail_slice = ops.Substring(arg, start=1, length=ops.StringLength(arg))
    tail = ops.Lowercase(tail_slice)
    return ops.StringConcat((head, tail))


@replace(p.Sample)
def rewrite_sample_as_filter(_, **kwargs):
"""Rewrite Sample as `t.filter(random() <= fraction)`.
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,6 @@ def visit_ArrayRepeat(self, op, *, arg, times):
sg.select(arg[self.f.safe_ordinal(idx)]).from_(self._unnest(series, as_=i))
)

def visit_Capitalize(self, op, *, arg):
    """Compile Capitalize as upper(first character) concatenated with lower(rest)."""
    # substr(arg, 1, 1) is the first character; substr(arg, 2) is everything after it
    head = self.f.upper(self.f.substr(arg, 1, 1))
    tail = self.f.lower(self.f.substr(arg, 2))
    return self.f.concat(head, tail)

def visit_NthValue(self, op, *, arg, nth):
if not isinstance(op.nth, ops.Literal):
raise com.UnsupportedOperationError(
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,6 @@ def visit_StringSplit(self, op, *, arg, delimiter):
delimiter, self.cast(arg, dt.String(nullable=False))
)

def visit_Capitalize(self, op, *, arg):
    """Compile Capitalize as upper(first character) concatenated with lower(rest)."""
    # substr(arg, 1, 1) is the first character; substr(arg, 2) is everything after it
    head = self.f.upper(self.f.substr(arg, 1, 1))
    tail = self.f.lower(self.f.substr(arg, 2))
    return self.f.concat(head, tail)

def visit_GroupConcat(self, op, *, arg, sep, where):
call = self.agg.groupArray(arg, where=where)
return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep))
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/dask/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@
param(lambda s: s.reverse(), lambda s: s.str[::-1], id="reverse"),
param(lambda s: s.lower(), lambda s: s.str.lower(), id="lower"),
param(lambda s: s.upper(), lambda s: s.str.upper(), id="upper"),
param(
lambda s: s.capitalize(),
lambda s: s.str.capitalize(),
id="capitalize",
),
param(lambda s: s.repeat(2), lambda s: s * 2, id="repeat"),
param(
lambda s: s.contains("a"),
Expand Down
14 changes: 12 additions & 2 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,25 @@
from ibis.backends.base.sqlglot.compiler import NULL, SQLGlotCompiler
from ibis.backends.base.sqlglot.datatypes import DruidType
from ibis.backends.base.sqlglot.dialects import Druid
from ibis.backends.base.sqlglot.rewrites import rewrite_sample_as_filter
from ibis.backends.base.sqlglot.rewrites import (
rewrite_capitalize,
rewrite_sample_as_filter,
)


class DruidCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = Druid
type_mapper = DruidType
rewrites = (rewrite_sample_as_filter, *SQLGlotCompiler.rewrites)
rewrites = (
rewrite_sample_as_filter,
*(
rewrite
for rewrite in SQLGlotCompiler.rewrites
if rewrite is not rewrite_capitalize
),
)

UNSUPPORTED_OPERATIONS = frozenset(
(
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
else:
return None

def visit_Capitalize(self, op, *, arg):
    """Compile Capitalize as upper(first character) concatenated with lower(rest)."""
    # substr(arg, 1, 1) is the first character; substr(arg, 2) is everything after it
    head = self.f.upper(self.f.substr(arg, 1, 1))
    tail = self.f.lower(self.f.substr(arg, 2))
    return self.f.concat(head, tail)

def _neg_idx_to_pos(self, array, idx):
arg_length = self.f.array_size(array)
return self.if_(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
SELECT
INITCAP(`t0`.`string_col`) AS `Capitalize(string_col)`
CONCAT(
UPPER(
IF(
(
0 + 1
) >= 1,
SUBSTRING(`t0`.`string_col`, 0 + 1, 1),
SUBSTRING(`t0`.`string_col`, 0 + 1 + LENGTH(`t0`.`string_col`), 1)
)
),
LOWER(
IF(
(
1 + 1
) >= 1,
SUBSTRING(`t0`.`string_col`, 1 + 1, LENGTH(`t0`.`string_col`)),
SUBSTRING(`t0`.`string_col`, 1 + 1 + LENGTH(`t0`.`string_col`), LENGTH(`t0`.`string_col`))
)
)
) AS `Capitalize(string_col)`
FROM `functional_alltypes` AS `t0`
1 change: 0 additions & 1 deletion ibis/backends/impala/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def test_decimal_builtins_2(con, func, expected):
(L(" a ").strip(), "a"),
(L(" a ").lstrip(), "a "),
(L(" a ").rstrip(), " a"),
(L("abcd").capitalize(), "Abcd"),
(L("abcd").substr(0, 2), "ab"),
(L("abcd").left(2), "ab"),
(L("abcd").right(2), "cd"),
Expand Down
7 changes: 0 additions & 7 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,6 @@ def visit_StringLength(self, op, *, arg):
"""
return paren(self.f.len(self.f.concat("A", arg, "Z")) - 2)

def visit_Capitalize(self, op, *, arg):
    """Compile Capitalize as upper(first character) concatenated with lower(rest)."""
    # Length is measured on the argument wrapped in "A"/"Z" sentinels, minus 2.
    # NOTE(review): presumably this guards against LEN ignoring trailing
    # whitespace -- confirm against the backend's documentation.
    padded = self.f.concat("A", arg, "Z")
    length = paren(self.f.len(padded) - 2)
    head = self.f.upper(self.f.substring(arg, 1, 1))
    tail = self.f.lower(self.f.substring(arg, 2, length - 1))
    return self.f.concat(head, tail)

def visit_GroupConcat(self, op, *, arg, sep, where):
if where is not None:
arg = self.if_(where, arg, NULL)
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,6 @@ def visit_StringFind(self, op, *, arg, substr, start, end):
return self.f.locate(substr, arg, start + 1)
return self.f.locate(substr, arg)

def visit_Capitalize(self, op, *, arg):
    """Compile Capitalize as upper(left-most character) concatenated with lower(rest)."""
    # left(arg, 1) is the first character; substr(arg, 2) is everything after it
    head = self.f.upper(self.f.left(arg, 1))
    tail = self.f.lower(self.f.substr(arg, 2))
    return self.f.concat(head, tail)

def visit_LRStrip(self, op, *, arg, position):
return reduce(
lambda arg, char: self.f.trim(this=arg, position=position, expression=char),
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/pandas/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@
param(lambda s: s.reverse(), lambda s: s.str[::-1], id="reverse"),
param(lambda s: s.lower(), lambda s: s.str.lower(), id="lower"),
param(lambda s: s.upper(), lambda s: s.str.upper(), id="upper"),
param(
lambda s: s.capitalize(),
lambda s: s.str.capitalize(),
id="capitalize",
),
param(lambda s: s.repeat(2), lambda s: s * 2, id="repeat"),
param(
lambda s: s.contains("a"),
Expand Down
10 changes: 9 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def in_values(op, **kw):
ops.RStrip: "strip_chars_end",
ops.Lowercase: "to_lowercase",
ops.Uppercase: "to_uppercase",
ops.Capitalize: "to_titlecase",
}


Expand All @@ -514,6 +513,15 @@ def string_length(op, **kw):
return arg.str.len_bytes().cast(typ)


@translate.register(ops.Capitalize)
def capitalize(op, **kw):
    """Translate Capitalize: uppercase the first character, lowercase the rest."""
    expr = translate(op.arg, **kw)
    target_type = dtype_to_polars(op.dtype)
    # slice(0, 1) is the first character; slice(1, None) is everything after it
    head = expr.str.slice(0, 1).str.to_uppercase()
    tail = expr.str.slice(1, None).str.to_lowercase()
    combined = head + tail
    return combined.cast(target_type)


@translate.register(ops.StringUnary)
def string_unary(op, **kw):
arg = translate(op.arg, **kw)
Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,6 @@ def test_string_contains(con, haystack, needle, expected):
assert con.execute(expr) == expected


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        ("foo bar foo", "Foo Bar Foo"),
        ("foobar Foo", "Foobar Foo"),
    ],
)
def test_capitalize(con, value, expected):
    """Check capitalize() on a string literal; expected values capitalize every word."""
    capitalized = L(value).capitalize()
    result = con.execute(capitalized)
    assert result == expected


def test_repeat(con):
expr = L("bar ").repeat(3)
assert con.execute(expr) == "bar bar bar "
Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/risingwave/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def test_string_contains(con, haystack, needle, expected):
assert con.execute(expr) == expected


@pytest.mark.parametrize(
    ("value", "expected"),
    [
        ("foo bar foo", "Foo Bar Foo"),
        ("foobar Foo", "Foobar Foo"),
    ],
)
def test_capitalize(con, value, expected):
    """Check capitalize() on a string literal; expected values capitalize every word."""
    capitalized = L(value).capitalize()
    result = con.execute(capitalized)
    assert result == expected


def test_repeat(con):
expr = L("bar ").repeat(3)
assert con.execute(expr) == "bar bar bar "
Expand Down
5 changes: 0 additions & 5 deletions ibis/backends/sqlite/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,6 @@ def _ibis_string_ascii(string):
return ord(string[0])


@udf
def _ibis_capitalize(string):
    """Return `string` with Python's `str.capitalize` semantics applied."""
    capitalized = string.capitalize()
    return capitalized


@udf
def _ibis_rpad(string, width, pad):
return string.ljust(width, pad)[:width]
Expand Down
54 changes: 45 additions & 9 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,6 @@ def uses_java_re(t):
lambda t: t.string_col.str.rstrip(),
id="rstrip",
),
param(
lambda t: t.string_col.capitalize(),
lambda t: t.string_col.str.capitalize(),
id="capitalize",
),
param(
lambda t: t.date_string_col.substr(2, 3),
lambda t: t.date_string_col.str[2:5],
Expand Down Expand Up @@ -844,11 +839,52 @@ def test_parse_url(con, result_func, expected):
assert result == expected


@pytest.mark.parametrize(
    ("inp", "expected"),
    [
        param(
            None,
            None,
            id="none",
            marks=[
                pytest.mark.notyet(
                    ["druid"],
                    raises=PyDruidProgrammingError,
                    reason="illegal use of NULL",
                )
            ],
        ),
        param(
            "",
            "",
            id="empty",
            marks=[
                pytest.mark.notyet(
                    ["oracle"],
                    reason="https://github.com/oracle/python-oracledb/issues/298",
                    raises=AssertionError,
                ),
                pytest.mark.notyet(["exasol"], raises=AssertionError),
            ],
        ),
        param("Abc", "Abc", id="no_change"),
        param("abc", "Abc", id="lower_to_upper"),
        param("aBC", "Abc", id="mixed_to_upper"),
        param(" abc", " abc", id="leading_space"),
        param("9abc", "9abc", id="leading_digit"),
        param("aBc dEf", "Abc def", id="mixed_with_space"),
        param("aBc-dEf", "Abc-def", id="mixed_with_hyphen"),
        param("aBc1dEf", "Abc1def", id="mixed_with_digit"),
    ],
)
def test_capitalize(con, inp, expected):
    """Check that capitalize() uppercases only the first character and lowercases the rest.

    The cases cover NULL and empty input, strings that start with a
    non-letter, and multi-word strings (only the first word is capitalized).
    """
    s = ibis.literal(inp, type="string")
    expr = s.capitalize()
    # Execute once and compare null-aware: a NULL input must yield a null
    # result, which cannot be checked with a plain equality assertion.
    result = con.execute(expr)
    if expected is not None:
        assert result == expected
    else:
        assert pd.isnull(result)


@pytest.mark.notimpl(
Expand Down
19 changes: 15 additions & 4 deletions ibis/expr/types/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def rstrip(self) -> StringValue:
return ops.RStrip(self).to_expr()

def capitalize(self) -> StringValue:
"""Capitalize the input string.
"""Uppercase the first letter, lowercase the rest.
This API matches the semantics of the Python [](`str.capitalize`)
method.
Returns
-------
Expand All @@ -399,22 +402,30 @@ def capitalize(self) -> StringValue:
--------
>>> import ibis
>>> ibis.options.interactive = True
>>> t = ibis.memtable({"s": ["abc", "def", "ghi"]})
>>> t = ibis.memtable({"s": ["aBC", " abc", "ab cd", None]})
>>> t.s.capitalize()
┏━━━━━━━━━━━━━━━┓
┃ Capitalize(s) ┃
┡━━━━━━━━━━━━━━━┩
│ string │
├───────────────┤
│ Abc │
│ Def │
│ Ghi │
│ abc │
│ Ab cd │
│ NULL │
└───────────────┘
"""
return ops.Capitalize(self).to_expr()

initcap = capitalize

@util.deprecated(
    instead="use the `capitalize` method",
    as_of="9.0",
    removed_in="10.0",
)
def initcap(self) -> StringValue:
    """Deprecated alias that forwards to `capitalize`.

    Use `capitalize` instead.
    """
    capitalized = self.capitalize()
    return capitalized

def __contains__(self, *_: Any) -> bool:
raise TypeError("Use string_expr.contains(arg)")

Expand Down

0 comments on commit c4055d6

Please sign in to comment.