
Optional dtype preservation in pandas output for subset-preserving transformers #52

@rowan-stein

Description


Preserving dtypes for DataFrame output by transformers that do not modify the input values

User request

It would be nice to optionally preserve the input dtypes when transformers produce pandas output.
Dtypes can carry information that is relevant to later steps of an analysis.
For example, if I use pd.Categorical columns to represent ordinal data and then select features with a sklearn transformer, the columns lose their categorical dtype. This means I lose information that later analysis steps depend on.
This is not only relevant for categorical dtypes; it applies to other dtypes as well (existing, future, and custom).
Furthermore, this would make it possible to use ColumnTransformer sequentially while preserving the dtypes.

Current behavior (minimal illustration):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import chi2

X, y = load_iris(return_X_y=True, as_frame=True)
# Downcast two columns to float16 and add a categorical column built from the target.
X = X.astype(
    {
        "petal width (cm)": np.float16,
        "petal length (cm)": np.float16,
    }
)
X["cat"] = y.astype("category")

# Keep the two best features and request pandas output.
selector = SelectKBest(chi2, k=2)
selector.set_output(transform="pandas")
X_out = selector.fit_transform(X, y)
print(X_out.dtypes)

Output (using sklearn version '1.2.dev0'):

petal length (cm)    float64
cat                  float64
dtype: object

The output shows that both the categorical column and the np.float16 column are converted to np.float64 in the DataFrame output.

Proposed solution (specification)

Add a safe, opt-in mechanism to preserve input pandas dtypes when set_output(transform="pandas") is used, but only for transformers that do not modify input values and whose output columns map directly to input column names (subset/reorder).
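scikit-learn already exposes set_config, get_config, and config_context, so an opt-in flag would naturally follow that pattern. The block below is a simplified, self-contained stand-in, not scikit-learn's actual sklearn/_config.py (which also keeps per-thread state). It only illustrates how a preserve_output_dtypes flag that defaults to False could be toggled globally or temporarily.

```python
from contextlib import contextmanager

# Hypothetical, simplified global config; only the two options relevant here.
_global_config = {"transform_output": "default", "preserve_output_dtypes": False}


def get_config():
    """Return a copy so callers cannot mutate the global state by accident."""
    return dict(_global_config)


def set_config(transform_output=None, preserve_output_dtypes=None):
    """Update only the options that were explicitly passed."""
    if transform_output is not None:
        _global_config["transform_output"] = transform_output
    if preserve_output_dtypes is not None:
        _global_config["preserve_output_dtypes"] = bool(preserve_output_dtypes)


@contextmanager
def config_context(**new_config):
    """Temporarily override options, restoring the previous values on exit."""
    old_config = get_config()
    set_config(**new_config)
    try:
        yield
    finally:
        set_config(**old_config)
```

Because the default stays False, existing code sees no change unless it opts in, e.g. with `config_context(transform_output="pandas", preserve_output_dtypes=True)`.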

Design:

  • Introduce a global configuration flag: preserve_output_dtypes: bool = False, configurable via set_config(...) and config_context(...). Default remains unchanged (False).
  • When transform_output == "pandas" and preserve_output_dtypes == True, attempt dtype preservation only when:
    • original_input is a pd.DataFrame.
    • estimator.get_feature_names_out() returns column names all present in original_input.columns (subset/reorder; no new columns).
    • The output values are unchanged relative to original_input[columns] (element-wise equality).
  • Otherwise, skip preservation silently (retain current behavior).
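The three preconditions can be expressed as a single predicate. This is a minimal sketch: the function name preservation_applies and its parameters are hypothetical, and `columns` is assumed to be whatever get_feature_names_out() returned (or None if it could not be computed). Values are compared as object arrays so that mixed dtypes (e.g. float16 next to a categorical) can be checked element by element.

```python
import numpy as np
import pandas as pd


def preservation_applies(output, original_input, columns, preserve_dtypes):
    """Return True only when every precondition for dtype preservation holds."""
    if not preserve_dtypes:
        return False  # opt-in flag is off: keep the current behavior
    if not isinstance(original_input, pd.DataFrame) or columns is None:
        return False  # non-DataFrame input, or feature names unavailable
    columns = list(columns)
    if not set(columns).issubset(original_input.columns):
        return False  # the transformer introduced new columns
    if output.shape != (original_input.shape[0], len(columns)):
        return False
    subset = original_input[columns]
    # Element-wise equality; NaN != NaN, so frames with missing values fail
    # this check and preservation is skipped.
    return bool(
        np.array_equal(output.to_numpy(dtype=object), subset.to_numpy(dtype=object))
    )
```

A selector that only drops or reorders columns passes every check, while anything that rescales values or creates new columns falls back to the current behavior.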

Implementation outline (minimal changes):

  • sklearn/_config.py
    • Add preserve_output_dtypes to the global config and expose it via set_config/config_context.
  • sklearn/utils/_set_output.py
    • Update _get_output_config(method, estimator) to include preserve_dtypes from global config.
    • Extend _wrap_in_pandas_container(...) signature to accept original_input=None, preserve_dtypes=False.
    • After constructing the pandas DataFrame (or updating columns/index in-place), if preservation preconditions are met, cast per-column dtypes using the original input DataFrame:
      1. subset_df = original_input[columns]
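      2. Verify np.array_equal(df.to_numpy(), subset_df.to_numpy()); if false, skip casting. Note that NaN != NaN, so frames with missing values fail this check unless a NaN-aware comparison is used.
      3. For each position i, replace the column with df.isetitem(i, df.iloc[:, i].astype(subset_df.dtypes.iloc[i])) inside a try/except. Plain df.iloc[:, i] = ... assignment is not enough: on pandas >= 2.0 it writes the values in place and keeps the existing float64 dtype.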
    • Update _wrap_data_with_container(...) to pass original_input and preserve_dtypes to _wrap_in_pandas_container.
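Steps 1 to 3 of the casting logic could look roughly like the sketch below. The name restore_input_dtypes is hypothetical, and it reads the column names from df.columns instead of calling get_feature_names_out() so that it stays self-contained. It relies on DataFrame.isetitem (available since pandas 1.5), which replaces a column by position rather than writing into the existing float64 array.

```python
import numpy as np
import pandas as pd


def restore_input_dtypes(df, original_input):
    """Recast df's columns (in place) to the dtypes they had in original_input."""
    if not set(df.columns).issubset(original_input.columns):
        return df  # new columns: skip preservation
    # Step 1: the matching slice of the original input.
    subset_df = original_input[list(df.columns)]
    if subset_df.shape != df.shape:
        return df  # e.g. duplicate names in the input made the subset ambiguous
    # Step 2: only cast when the values are unchanged.
    if not np.array_equal(df.to_numpy(dtype=object), subset_df.to_numpy(dtype=object)):
        return df
    # Step 3: cast by position so duplicate column names are not an issue.
    for i, dtype in enumerate(subset_df.dtypes):
        try:
            df.isetitem(i, df.iloc[:, i].astype(dtype))
        except (TypeError, ValueError):
            continue  # leave this column as float64 if the cast fails
    return df
```

Scaled or otherwise modified values fail step 2, so the frame is returned untouched with the current float64 dtypes.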

Dtype handling details:

  • Preserve CategoricalDtype (including ordered categories) and pandas extension dtypes (e.g., StringDtype, Int64Dtype) via Series.astype(dtype).
  • Preserve numeric precision (e.g., float16, int8) by re-casting after DataFrame construction, only when values are unchanged.
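As a quick check of the claim that Series.astype(dtype) is enough to recover these dtypes, the sketch below starts from hypothetical input columns (an ordered categorical, a nullable Int64 column, and a float16 column), simulates the float64 frame that pandas output currently produces, and casts each column back.

```python
import numpy as np
import pandas as pd

# Hypothetical input columns whose dtypes carry extra information.
original = pd.DataFrame(
    {
        "grade": pd.Categorical([1, 3, 2], categories=[1, 2, 3], ordered=True),
        "count": pd.array([5, 7, 9], dtype="Int64"),
        "width": np.array([0.25, 0.5, 1.0], dtype=np.float16),
    }
)

# Simulated float64 output for a transformer that did not change the values.
as_float64 = pd.DataFrame(
    {"grade": [1.0, 3.0, 2.0], "count": [5.0, 7.0, 9.0], "width": [0.25, 0.5, 1.0]}
)

# Cast each column back to its original dtype, including category order.
restored = pd.DataFrame(
    {name: as_float64[name].astype(dtype) for name, dtype in original.dtypes.items()}
)
```

The ordered flag survives the round trip, so operations such as `restored["grade"].max()` still respect the category order.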

Edge cases:

  • New columns created by the transformer: skip preservation.
  • Duplicate column names: cast by position rather than name to avoid ambiguity.
  • Transformed values (e.g., scaling/encoding): skip preservation.
  • Sparse outputs: still unsupported for pandas output.
  • Non-DataFrame input: skip preservation.
  • Columns callable raises an exception or returns None: skip preservation.
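The duplicate-name case is easiest to see with two input columns that share a name but not a dtype. In this sketch the data is made up, and it assumes that output column i corresponds to input column i (a pure subset in the same order). With duplicate names, selecting by name can return several columns at once, which is why a position-based mapping is needed.

```python
import numpy as np
import pandas as pd

# Two input columns share the name "x" but carry different dtypes.
original = pd.concat(
    [
        pd.Series([1, 2], dtype="int8", name="x"),
        pd.Series([0.5, 1.5], dtype="float16", name="x"),
    ],
    axis=1,
)

# Selecting by name is ambiguous: it returns both columns as a DataFrame.
by_name = original["x"]

# Simulated float64 output with the same (duplicated) column names.
output = pd.DataFrame([[1.0, 0.5], [2.0, 1.5]], columns=["x", "x"])

# Cast by position instead, pairing output column i with input column i.
for i, dtype in enumerate(original.dtypes):
    output.isetitem(i, output.iloc[:, i].astype(dtype))
```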

Testing plan:

  • Add unit tests in sklearn/utils/tests/test_set_output.py verifying:
    • Preservation for subset/reorder selectors (including mixed dtypes and CategoricalDtype).
    • No preservation when values are modified (e.g., StandardScaler).
    • No preservation when new columns are introduced (e.g., OneHotEncoder).
    • Handling of duplicate column names via positional casting.

This approach provides the requested ability to preserve statistically relevant dtype information in the pandas output for applicable transformers, while ensuring safety and avoiding incorrect casts when transformations alter values.

Metadata


Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
