Skip to content

Commit 802d937

Browse files
committed
Fix dummies
1 parent 3984ce1 commit 802d937

File tree

3 files changed

+89
-30
lines changed

3 files changed

+89
-30
lines changed

crates/polars-ops/src/series/ops/to_dummies.rs

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,36 @@ impl ToDummies for Series {
3030
) -> PolarsResult<DataFrame> {
3131
let sep = separator.unwrap_or("_");
3232
let col_name = self.name();
33-
let groups = self.group_tuples(true, drop_first)?;
34-
35-
// SAFETY: groups are in bounds
33+
let has_nulls = self.has_nulls();
34+
let is_sorted = self.is_sorted(SortOptions::new())?;
35+
let mut groups = self.group_tuples(true, is_sorted)?;
36+
groups.sort();
37+
38+
// If we have nulls, the first group is the null group. If drop_first is called, we must
39+
// ensure that it's not the null group that we drop. If we are not also dropping nulls, and
40+
// we have a null column, the second group is the one that we need to drop.
41+
let remove_second = drop_first && !drop_nulls && has_nulls;
42+
43+
let skip = if drop_first && drop_nulls && has_nulls {
44+
// If nulls are present and we are dropping both nulls and the first column, we skip the
45+
// first two columns.
46+
2
47+
} else if (drop_first && !has_nulls) || (drop_nulls && has_nulls) {
48+
// We skip the first column if either we have nulls and we want to drop them, or if
49+
// we don't have nulls and want to drop the first column.
50+
1
51+
} else {
52+
0
53+
};
54+
55+
// SAFETY: groups are in bounds.
3656
let columns = unsafe { self.agg_first(&groups) };
37-
let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize);
57+
let columns = columns.iter().zip(groups.iter()).enumerate().skip(skip);
3858
let columns = columns
39-
.filter_map(|(av, group)| {
59+
.filter_map(|(idx, (av, group))| {
60+
if remove_second && idx == 1 {
61+
return None;
62+
}
4063
// strings are formatted with extra \" \" in polars, so we
4164
// extract the string
4265
let name = if let Some(s) = av.get_str() {
@@ -46,10 +69,6 @@ impl ToDummies for Series {
4669
format_pl_smallstr!("{col_name}{sep}{av}")
4770
};
4871

49-
if av.is_null() && drop_nulls {
50-
return None;
51-
}
52-
5372
let ca = match group {
5473
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
5574
GroupsIndicator::Slice([offset, len]) => {

py-polars/src/polars/series/series.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,8 @@ def to_dummies(
22942294
drop_first
22952295
Remove the first category from the variable being encoded.
22962296
drop_nulls
2297-
If there are `None` values in the series, a `null` column is not generated
2297+
If there are `None` values in the series, a `null` column is not generated.
2298+
Null values in the nput are represented by zero vectors.
22982299
22992300
Examples
23002301
--------
@@ -2322,6 +2323,20 @@ def to_dummies(
23222323
│ 1 ┆ 0 │
23232324
│ 0 ┆ 1 │
23242325
└─────┴─────┘
2326+
2327+
>>> s = pl.Series("a", [1, 2, None, 3])
2328+
>>> s.to_dummies(drop_nulls=True, drop_first=True)
2329+
shape: (4, 2)
2330+
┌─────┬────────┐
2331+
│ a_3 ┆ a_null │
2332+
│ --- ┆ --- │
2333+
│ u8 ┆ u8 │
2334+
╞═════╪════════╡
2335+
│ 0 ┆ 0 │
2336+
│ 0 ┆ 0 │
2337+
│ 0 ┆ 1 │
2338+
│ 1 ┆ 0 │
2339+
└─────┴────────┘
23252340
"""
23262341
return wrap_df(self._s.to_dummies(separator, drop_first, drop_nulls))
23272342

py-polars/tests/unit/series/test_series.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,33 +1360,58 @@ def test_temporal_comparison(
13601360
)
13611361

13621362

1363-
def test_to_dummies() -> None:
1364-
s = pl.Series("a", [1, 2, 3])
1365-
result = s.to_dummies()
1363+
@pytest.mark.parametrize(
1364+
("drop_nulls", "drop_first"),
1365+
[
1366+
(False, False),
1367+
(False, True),
1368+
(True, False),
1369+
(True, True),
1370+
],
1371+
)
1372+
def test_to_dummies_with_nulls(drop_nulls: bool, drop_first: bool) -> None:
1373+
s = pl.Series("s", [None, "a", "a", None, "b", "c"])
13661374
expected = pl.DataFrame(
1367-
{"a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]},
1368-
schema={"a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8},
1369-
)
1375+
{
1376+
"s_a": [0, 1, 1, 0, 0, 0],
1377+
"s_b": [0, 0, 0, 0, 1, 0],
1378+
"s_c": [0, 0, 0, 0, 0, 1],
1379+
"s_null": [1, 0, 0, 1, 0, 0],
1380+
}
1381+
).cast(pl.UInt8)
1382+
1383+
if drop_nulls:
1384+
expected = expected.drop("s_null")
1385+
if drop_first:
1386+
expected = expected.drop("s_a")
1387+
1388+
result = s.to_dummies(drop_nulls=drop_nulls, drop_first=drop_first)
13701389
assert_frame_equal(result, expected)
13711390

13721391

1373-
def test_to_dummies_drop_first() -> None:
1374-
s = pl.Series("a", [1, 2, 3])
1375-
result = s.to_dummies(drop_first=True)
1392+
@pytest.mark.parametrize(
1393+
("drop_nulls", "drop_first"),
1394+
[
1395+
(False, False),
1396+
(False, True),
1397+
(True, False),
1398+
(True, True),
1399+
],
1400+
)
1401+
def test_to_dummies_no_nulls(drop_nulls: bool, drop_first: bool) -> None:
1402+
s = pl.Series("s", ["a", "a", "b", "c"])
13761403
expected = pl.DataFrame(
1377-
{"a_2": [0, 1, 0], "a_3": [0, 0, 1]},
1378-
schema={"a_2": pl.UInt8, "a_3": pl.UInt8},
1379-
)
1380-
assert_frame_equal(result, expected)
1404+
{
1405+
"s_a": [1, 1, 0, 0],
1406+
"s_b": [0, 0, 1, 0],
1407+
"s_c": [0, 0, 0, 1],
1408+
}
1409+
).cast(pl.UInt8)
13811410

1411+
if drop_first:
1412+
expected = expected.drop("s_a")
13821413

1383-
def test_to_dummies_drop_nulls() -> None:
1384-
s = pl.Series("a", [1, 2, None])
1385-
result = s.to_dummies(drop_nulls=True)
1386-
expected = pl.DataFrame(
1387-
{"a_1": [1, 0, 0], "a_2": [0, 1, 0]},
1388-
schema={"a_1": pl.UInt8, "a_2": pl.UInt8},
1389-
)
1414+
result = s.to_dummies(drop_nulls=drop_nulls, drop_first=drop_first)
13901415
assert_frame_equal(result, expected)
13911416

13921417

0 commit comments

Comments
 (0)