Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions crates/polars-ops/src/series/ops/to_dummies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,24 @@ impl ToDummies for Series {
) -> PolarsResult<DataFrame> {
let sep = separator.unwrap_or("_");
let col_name = self.name();
let groups = self.group_tuples(true, drop_first)?;

// SAFETY: groups are in bounds
// We only need to maintain order if we need to drop the first non-null item.
let maintain_order = drop_first;
let groups = self.group_tuples(true, maintain_order)?;

// SAFETY: groups are in bounds.
let columns = unsafe { self.agg_first(&groups) };
let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize);
let columns = columns.iter().zip(groups.iter());
let mut seen_first = false;
let columns = columns
.filter_map(|(av, group)| {
if av.is_null() && drop_nulls {
return None;
} else if !seen_first && !av.is_null() && drop_first {
// The position of the first non-null item could be either 0 or 1.
seen_first = true;
return None;
}
// strings are formatted with extra \" \" in polars, so we
// extract the string
let name = if let Some(s) = av.get_str() {
Expand All @@ -46,10 +57,6 @@ impl ToDummies for Series {
format_pl_smallstr!("{col_name}{sep}{av}")
};

if av.is_null() && drop_nulls {
return None;
}

let ca = match group {
GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name),
GroupsIndicator::Slice([offset, len]) => {
Expand Down
17 changes: 16 additions & 1 deletion py-polars/src/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,7 +2294,8 @@ def to_dummies(
drop_first
Remove the first category from the variable being encoded.
drop_nulls
If there are `None` values in the series, a `null` column is not generated
If there are `None` values in the series, a `null` column is not generated.
Null values in the nput are represented by zero vectors.

Examples
--------
Expand Down Expand Up @@ -2322,6 +2323,20 @@ def to_dummies(
│ 1 ┆ 0 │
│ 0 ┆ 1 │
└─────┴─────┘

>>> s = pl.Series("a", [1, 2, None, 3])
>>> s.to_dummies(drop_nulls=True, drop_first=True)
shape: (4, 2)
┌─────┬─────┐
│ a_2 ┆ a_3 │
│ --- ┆ --- │
│ u8 ┆ u8 │
╞═════╪═════╡
│ 0 ┆ 0 │
│ 1 ┆ 0 │
│ 0 ┆ 0 │
│ 0 ┆ 1 │
└─────┴─────┘
"""
return wrap_df(self._s.to_dummies(separator, drop_first, drop_nulls))

Expand Down
65 changes: 45 additions & 20 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,33 +1360,58 @@ def test_temporal_comparison(
)


def test_to_dummies() -> None:
s = pl.Series("a", [1, 2, 3])
result = s.to_dummies()
@pytest.mark.parametrize(
("drop_nulls", "drop_first"),
[
(False, False),
(False, True),
(True, False),
(True, True),
],
)
def test_to_dummies_with_nulls(drop_nulls: bool, drop_first: bool) -> None:
s = pl.Series("s", [None, "a", "a", None, "b", "c"])
expected = pl.DataFrame(
{"a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]},
schema={"a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8},
)
{
"s_a": [0, 1, 1, 0, 0, 0],
"s_b": [0, 0, 0, 0, 1, 0],
"s_c": [0, 0, 0, 0, 0, 1],
"s_null": [1, 0, 0, 1, 0, 0],
}
).cast(pl.UInt8)

if drop_nulls:
expected = expected.drop("s_null")
if drop_first:
expected = expected.drop("s_a")

result = s.to_dummies(drop_nulls=drop_nulls, drop_first=drop_first)
assert_frame_equal(result, expected)


def test_to_dummies_drop_first() -> None:
s = pl.Series("a", [1, 2, 3])
result = s.to_dummies(drop_first=True)
@pytest.mark.parametrize(
("drop_nulls", "drop_first"),
[
(False, False),
(False, True),
(True, False),
(True, True),
],
)
def test_to_dummies_no_nulls(drop_nulls: bool, drop_first: bool) -> None:
s = pl.Series("s", ["a", "a", "b", "c"])
expected = pl.DataFrame(
{"a_2": [0, 1, 0], "a_3": [0, 0, 1]},
schema={"a_2": pl.UInt8, "a_3": pl.UInt8},
)
assert_frame_equal(result, expected)
{
"s_a": [1, 1, 0, 0],
"s_b": [0, 0, 1, 0],
"s_c": [0, 0, 0, 1],
}
).cast(pl.UInt8)

if drop_first:
expected = expected.drop("s_a")

def test_to_dummies_drop_nulls() -> None:
s = pl.Series("a", [1, 2, None])
result = s.to_dummies(drop_nulls=True)
expected = pl.DataFrame(
{"a_1": [1, 0, 0], "a_2": [0, 1, 0]},
schema={"a_1": pl.UInt8, "a_2": pl.UInt8},
)
result = s.to_dummies(drop_nulls=drop_nulls, drop_first=drop_first)
assert_frame_equal(result, expected)


Expand Down
Loading