Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): move from .case() to .cases() #9039

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions docs/posts/ci-analysis/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots.
stats = stats.mutate(
raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int")
).mutate(
improvements=(
_.raw_improvements.case()
.when(0, "None")
.when(1, "Poetry")
.when(2, "Poetry + Team Plan")
.else_("NA")
.end()
),
improvements=_.raw_improvements.cases(
(0, "None"),
(1, "Poetry"),
(2, "Poetry + Team Plan"),
else_="NA",
)
team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"),
)
stats
Expand Down
24 changes: 11 additions & 13 deletions docs/tutorials/ibis-for-sql-users.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,11 @@ semantics:
case = (
t.one.cast("timestamp")
.year()
.case()
.when(2015, "This year")
.when(2014, "Last year")
.else_("Earlier")
.end()
.cases(
(2015, "This year"),
(2014, "Last year"),
else_="Earlier",
)
)
expr = t.mutate(year_group=case)
Expand All @@ -496,18 +496,16 @@ CASE
END
```

To do this, use `ibis.case`:
To do this, use `ibis.cases`:

```{python}
case = (
ibis.case()
.when(t.two < 0, t.three * 2)
.when(t.two > 1, t.three)
.else_(t.two)
.end()
cases = ibis.cases(
(t.two < 0, t.three * 2),
(t.two > 1, t.three),
else_=t.two,
)
expr = t.mutate(cond_value=case)
expr = t.mutate(cond_value=cases)
ibis.to_sql(expr)
```

Expand Down
14 changes: 5 additions & 9 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,18 @@ def test_ifelse(alltypes, df, op, pandas_op):

def test_simple_case(con, alltypes, assert_sql):
t = alltypes
expr = (
t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end()
)
expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default")

assert_sql(expr)
assert len(con.execute(expr))


def test_search_case(con, alltypes, assert_sql):
t = alltypes
expr = (
ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end()
expr = ibis.cases(
(t.float_col > 0, t.int_col * 2),
(t.float_col < 0, t.int_col),
else_=0,
)

assert_sql(expr)
Expand Down
21 changes: 5 additions & 16 deletions ibis/backends/dask/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def q_fun(x, quantile):


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
expr = ibis.cases((True, 1), (False, 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -783,12 +783,8 @@ def test_searched_case_scalar(client):
def test_searched_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID
)
result = expr.execute()
expected = pd.Series(
Expand All @@ -803,7 +799,7 @@ def test_searched_case_column(batting, batting_pandas_df):

def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -812,14 +808,7 @@ def test_simple_case_scalar(client):
def test_simple_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
)
expr = t.RBI.cases((5, "five"), (4, "four"), (3, "three"), else_="could be good?")
result = expr.execute()
expected = pd.Series(
np.select(
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def table(mockcon):

@pytest.fixture
def simple_case(table):
return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture
def search_case(table):
t = table
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def visit(cls, op: ops.IsNan, arg):
def visit(
cls, op: ops.SearchedCase | ops.SimpleCase, cases, results, default, base=None
):
if not cases:
return default
if base is not None:
cases = tuple(base == case for case in cases)
cases, _ = cls.asframe(cases, concat=False)
Expand Down
23 changes: 7 additions & 16 deletions ibis/backends/pandas/tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def test_summary_non_numeric(batting, batting_df):


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
expr = ibis.cases((True, 1), (False, 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -693,12 +693,10 @@ def test_searched_case_scalar(client):
def test_searched_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
expr = ibis.cases(
(t.RBI < 5, "really bad team"),
(t.teamID == "PH1", "ph1 team"),
else_=t.teamID,
)
result = expr.execute()
expected = pd.Series(
Expand All @@ -713,7 +711,7 @@ def test_searched_case_column(batting, batting_df):

def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
expr = x.cases((2, x - 1), (3, x + 1), (4, x + 2))
result = client.execute(expr)
expected = np.int8(1)
assert result == expected
Expand All @@ -722,14 +720,7 @@ def test_simple_case_scalar(client):
def test_simple_case_column(batting, batting_df):
t = batting
df = batting_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
)
expr = t.RBI.cases((5, "five"), (4, "four"), (3, "three"), else_="could be good?")
result = expr.execute()
expected = pd.Series(
np.select(
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,8 @@ def predict_price(

def cases(value, mapping):
"""This should really be a top-level function or method."""
expr = ibis.case()
for k, v in mapping.items():
expr = expr.when(value == k, v)
return expr.end()
pairs = [(value == k, v) for k, v in mapping.items()]
return ibis.cases(*pairs)

diamonds = con.tables.DIAMONDS
expr = diamonds.mutate(
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,8 @@ def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw):
)

def visit_SimpleCase(self, op, *, base=None, cases, results, default):
if not cases:
return default
return sge.Case(
this=base, ifs=list(map(self.if_, cases, results)), default=default
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ def difference(con):
@pytest.fixture(scope="module")
def simple_case(con):
t = con.table("alltypes")
return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default")


@pytest.fixture(scope="module")
def search_case(con):
t = con.table("alltypes")
return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end()
return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2))


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
lit2 = ibis.literal("bar")

result = alltypes.select(
alltypes.g.case()
.when(lit, lit2)
.when(lit1, ibis.literal("qux"))
.else_(ibis.literal("default"))
.end()
.name("col1"),
ibis.case()
.when(alltypes.g == lit, lit2)
.when(alltypes.g == lit1, alltypes.g)
.else_(ibis.literal(None).cast("string"))
.end()
.name("col2"),
alltypes.g.cases(
(lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default")
).name("col1"),
ibis.cases(
(alltypes.g == lit, lit2),
(alltypes.g == lit1, alltypes.g),
else_=ibis.literal(None).cast("string"),
).name("col2"),
alltypes.a,
alltypes.b,
alltypes.c,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ def test_bool_bool(snapshot):

def test_case_in_projection(alltypes, snapshot):
t = alltypes
expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end()
expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end()
expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default"))
expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g))
expr = t[expr.name("col1"), expr2.name("col2"), t]

snapshot.assert_match(to_sql(expr), "out.sql")
Expand Down
8 changes: 3 additions & 5 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def test_arbitrary(backend, alltypes, df, filtered):
# _something_ we create a column that is a mix of nulls and a single value
# (or a single value after filtering is applied).
if filtered:
new = alltypes.int_col.cases([(3, 30), (4, 40)])
new = alltypes.int_col.cases((3, 30), (4, 40))
where = _.int_col == 3
else:
new = (alltypes.int_col == 3).ifelse(30, None)
Expand Down Expand Up @@ -1428,9 +1428,7 @@ def collect_udf(v):

def test_binds_are_cast(alltypes):
expr = alltypes.aggregate(
high_line_count=(
alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum()
)
high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum()
)

expr.execute()
Expand Down Expand Up @@ -1476,7 +1474,7 @@ def test_agg_name_in_output_column(alltypes):
def test_grouped_case(backend, con):
table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]})

case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end()
case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null())

expr = (
table.group_by(k="key").aggregate(mx=case_expr.max()).dropna("k").order_by("k")
Expand Down
Loading
Loading