Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,111 @@ def test_feature_union_set_output():
assert_array_equal(X_trans.index, X_test.index)


def test_feature_union_pandas_preserves_aggregated_index():
"""FeatureUnion should keep aggregator-defined index when lengths shrink."""
pd = pytest.importorskip("pandas")

class MeanAggregator(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
self.feature_names_in_ = list(X.columns)
return self

def transform(self, X, y=None):
summary = X.mean().to_frame().T
summary.index = pd.Index(["mean"], name="summary")
summary.columns = [f"mean_{name}" for name in self.feature_names_in_]
return summary

def get_feature_names_out(self, input_features=None):
return np.asarray(
[f"mean_{name}" for name in self.feature_names_in_], dtype=object
)

X, _ = load_iris(as_frame=True, return_X_y=True)
X_train, X_test = train_test_split(X, random_state=0)

union = FeatureUnion([("aggregate", MeanAggregator())])
union.set_output(transform="pandas")
union.fit(X_train)

X_trans = union.transform(X_test)

assert isinstance(X_trans, pd.DataFrame)
assert_array_equal(X_trans.columns, union.get_feature_names_out())
assert X_trans.index.equals(pd.Index(["mean"], name="summary"))


def test_feature_union_pandas_preserves_aggregated_index_series():
"""Series outputs keep their aggregated index under pandas wrapping."""
pd = pytest.importorskip("pandas")

class SeriesAggregator(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
self.feature_names_in_ = list(X.columns)
return self

def transform(self, X, y=None):
grouped = X.groupby("group")["value"].sum()
grouped.index.name = "group"
return grouped

def get_feature_names_out(self, input_features=None):
return np.asarray(["value_sum"], dtype=object)

X = pd.DataFrame(
{
"value": [1, 2, 3, 4],
"group": ["a", "a", "b", "b"],
},
index=pd.Index([0, 1, 2, 3], name="sample"),
)

union = FeatureUnion([("aggregate", SeriesAggregator())])
union.set_output(transform="pandas")
union.fit(X)

X_trans = union.transform(X)

assert isinstance(X_trans, pd.DataFrame)
assert_array_equal(X_trans.columns, union.get_feature_names_out())
assert X_trans.index.equals(pd.Index(["a", "b"], name="group"))


def test_feature_union_pandas_aligns_index_when_lengths_match():
"""FeatureUnion still aligns to original index when lengths match."""
pd = pytest.importorskip("pandas")

class IdentityWithCustomIndex(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
self.feature_names_in_ = list(X.columns)
return self

def transform(self, X, y=None):
renamed = [f"copy_{name}" for name in self.feature_names_in_]
df = X.copy()
df.columns = renamed
df.index = pd.Index(np.arange(len(df)) + 100, name="custom")
return df

def get_feature_names_out(self, input_features=None):
return np.asarray(
[f"copy_{name}" for name in self.feature_names_in_], dtype=object
)

X, _ = load_iris(as_frame=True, return_X_y=True)
X_train, X_test = train_test_split(X, random_state=1)

union = FeatureUnion([("identity", IdentityWithCustomIndex())])
union.set_output(transform="pandas")
union.fit(X_train)

X_trans = union.transform(X_test)

assert isinstance(X_trans, pd.DataFrame)
assert_array_equal(X_trans.columns, union.get_feature_names_out())
assert X_trans.index.equals(X_test.index)


def test_feature_union_getitem():
"""Check FeatureUnion.__getitem__ returns expected results."""
scalar = StandardScaler()
Expand Down
23 changes: 21 additions & 2 deletions sklearn/utils/_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,30 @@ def _wrap_in_pandas_container(
if isinstance(data_to_wrap, pd.DataFrame):
if columns is not None:
data_to_wrap.columns = columns
if index is not None:
if index is not None and len(index) == len(data_to_wrap):
data_to_wrap.index = index
# Keep the output's own index when lengths differ to preserve
# aggregator-defined indexing and avoid length mismatch errors.
return data_to_wrap

return pd.DataFrame(data_to_wrap, index=index, columns=columns)
if isinstance(data_to_wrap, pd.Series):
dataframe_to_wrap = data_to_wrap.to_frame()
if columns is not None:
dataframe_to_wrap.columns = columns
if index is not None and len(index) == len(dataframe_to_wrap):
dataframe_to_wrap.index = index
# Series outputs mirror the same constraint: mismatched lengths keep
# the Series-provided index so aggregation semantics survive.
return dataframe_to_wrap

dataframe_index = (
index
if index is not None and len(index) == len(data_to_wrap)
else None
)
# Ignore a mismatched index for ndarray outputs; pandas would otherwise
# raise due to unequal lengths, and downstream reducers may shorten rows.
return pd.DataFrame(data_to_wrap, index=dataframe_index, columns=columns)


def _get_output_config(method, estimator=None):
Expand Down
45 changes: 45 additions & 0 deletions sklearn/utils/tests/test_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,51 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index():
assert_array_equal(new_df.index, new_index)


def test__wrap_in_pandas_container_preserve_index_on_length_mismatch_dataframe():
"""Do not overwrite DataFrame index when length differs."""
pd = pytest.importorskip("pandas")
data = pd.DataFrame({"a": [1, 2]}, index=pd.Index([5, 6], name="agg"))
mismatched_index = pd.Index([0, 1, 2])

wrapped = _wrap_in_pandas_container(data, columns=None, index=mismatched_index)

assert wrapped.index.equals(data.index)
assert_array_equal(wrapped.columns, data.columns)


def test__wrap_in_pandas_container_ndarray_ignore_index_on_length_mismatch():
"""Ignore provided index for ndarray outputs when lengths mismatch."""
pd = pytest.importorskip("pandas")
X = np.asarray([[1, 2], [3, 4]])
mismatched_index = pd.Index([0, 1, 2])

wrapped = _wrap_in_pandas_container(
X,
columns=lambda: np.asarray(["f0", "f1"], dtype=object),
index=mismatched_index,
)

assert isinstance(wrapped, pd.DataFrame)
assert_array_equal(wrapped.index.to_numpy(), np.arange(len(X)))
assert_array_equal(wrapped.columns, np.asarray(["f0", "f1"], dtype=object))


def test__wrap_in_pandas_container_align_when_lengths_match():
"""Provided index is still applied when lengths match."""
pd = pytest.importorskip("pandas")
X = np.asarray([[1, 2], [3, 4]])
index = pd.Index([10, 11])

wrapped = _wrap_in_pandas_container(
X,
columns=lambda: np.asarray(["f0", "f1"], dtype=object),
index=index,
)

assert_array_equal(wrapped.index.to_numpy(), index.to_numpy())
assert_array_equal(wrapped.columns, np.asarray(["f0", "f1"], dtype=object))


def test__wrap_in_pandas_container_error_validation():
"""Check errors in _wrap_in_pandas_container."""
X = np.asarray([[1, 0, 3], [0, 0, 1]])
Expand Down