Description
Preserving dtypes for DataFrame output by transformers that do not modify the input values
User request
It would be nice to optionally preserve the dtypes of the input using pandas output for transformers.
Dtypes can contain information relevant for later steps of the analyses.
E.g. if I include `pd.Categorical` columns to represent ordinal data and then select features using a sklearn transformer, the columns lose their categorical dtype. This means I lose important information for later analysis steps.
This is not only relevant for categorical dtypes, but could extend to other dtypes (existing, future, and custom).
Furthermore, this would allow `ColumnTransformer` to be used sequentially while preserving the dtypes.
Current behavior (minimal illustration):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import chi2

X, y = load_iris(return_X_y=True, as_frame=True)
X = X.astype(
    {
        "petal width (cm)": np.float16,
        "petal length (cm)": np.float16,
    }
)
X["cat"] = y.astype("category")
selector = SelectKBest(chi2, k=2)
selector.set_output(transform="pandas")
X_out = selector.fit_transform(X, y)
print(X_out.dtypes)
```

Output (using sklearn version '1.2.dev0'):

```
petal length (cm)    float64
cat                  float64
dtype: object
```
The output shows that both the `category` and `np.float16` columns are converted to `np.float64` in the DataFrame output.
Proposed solution (specification)
Add a safe, opt-in mechanism to preserve input pandas dtypes when set_output(transform="pandas") is used, but only for transformers that do not modify input values and whose output columns map directly to input column names (subset/reorder).
Design:
- Introduce a global configuration flag `preserve_output_dtypes: bool = False`, configurable via `set_config(...)` and `config_context(...)`. The default remains unchanged (`False`).
- When `transform_output == "pandas"` and `preserve_output_dtypes == True`, attempt dtype preservation only when:
  - `original_input` is a `pd.DataFrame`;
  - `estimator.get_feature_names_out()` returns column names all present in `original_input.columns` (subset/reorder; no new columns);
  - the output values are unchanged relative to `original_input[columns]` (element-wise equality).
- Otherwise, skip preservation silently (retain current behavior).
Implementation outline (minimal changes):
- `sklearn/_config.py`
  - Add `preserve_output_dtypes` to the global config and expose it via `set_config` / `config_context`.
- `sklearn/utils/_set_output.py`
  - Update `_get_output_config(method, estimator)` to include `preserve_dtypes` from the global config.
  - Extend the `_wrap_in_pandas_container(...)` signature to accept `original_input=None, preserve_dtypes=False`.
  - After constructing the pandas DataFrame (or updating columns/index in place), if the preservation preconditions are met, cast per-column dtypes using the original input DataFrame:
    - `subset_df = original_input[columns]`
    - Verify `np.array_equal(df.to_numpy(), subset_df.to_numpy())`; if false, skip casting.
    - For each position `i`, cast `df.iloc[:, i] = df.iloc[:, i].astype(subset_df.dtypes.iloc[i])` inside a try/except.
  - Update `_wrap_data_with_container(...)` to pass `original_input` and `preserve_dtypes` to `_wrap_in_pandas_container`.
Dtype handling details:
- Preserve `CategoricalDtype` (including ordered categories) and pandas extension dtypes (e.g., `StringDtype`, `Int64Dtype`) via `Series.astype(dtype)`.
- Preserve numeric precision (e.g., `float16`, `int8`) by re-casting after DataFrame construction, only when values are unchanged.
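As an illustration of the `Series.astype(dtype)` round-trip (a sketch, not sklearn code): a value-preserving transform turns an ordered categorical column into `float64`, and re-casting with the stored input dtype restores it losslessly:

```python
import pandas as pd

# An ordered categorical column, as it would appear in the input DataFrame.
orig_dtype = pd.CategoricalDtype(categories=[0, 1, 2], ordered=True)
original = pd.Series([0, 1, 2, 1], dtype=orig_dtype)

transformed = original.astype("float64")   # what pandas output produces today
restored = transformed.astype(orig_dtype)  # re-cast using the input dtype

print(restored.dtype)        # category
print(restored.cat.ordered)  # True
```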
Edge cases:
- New columns created by the transformer: skip preservation.
- Duplicate column names: cast by position rather than name to avoid ambiguity.
- Transformed values (e.g., scaling/encoding): skip preservation.
- Sparse outputs: still unsupported for pandas output.
- Non-DataFrame input: skip preservation.
- Columns callable failure or `None`: skip preservation.
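The duplicate-column edge case is worth spelling out. With duplicate names, name-based indexing `X["x"]` returns a two-column DataFrame, so the dtype to restore is ambiguous by name but not by position; a minimal sketch of the positional cast:

```python
import numpy as np
import pandas as pd

# Build a frame with two columns that share the name "x" but differ in dtype.
a = pd.Series([1, 2], dtype=np.float16, name="x")
b = pd.Series([3, 4], dtype=np.int8, name="x")
X = pd.concat([a, b], axis=1)

# A value-preserving transform hands back plain float64 columns.
out = pd.DataFrame(X.to_numpy(np.float64), columns=X.columns)

# Restore dtypes position by position; isetitem replaces column i in place.
for i in range(out.shape[1]):
    out.isetitem(i, out.iloc[:, i].astype(X.dtypes.iloc[i]))

print([str(d) for d in out.dtypes])  # ['float16', 'int8']
```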
Testing plan:
- Add unit tests in `sklearn/utils/tests/test_set_output.py` verifying:
  - Preservation for subset/reorder selectors (including mixed dtypes and `CategoricalDtype`).
  - No preservation when values are modified (e.g., `StandardScaler`).
  - No preservation when new columns are introduced (e.g., `OneHotEncoder`).
  - Handling of duplicate column names via positional casting.
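A self-contained sketch of the first two test cases. The inline `preserve()` helper stands in for the proposed wrapper logic (it is hypothetical, not an existing sklearn function), so the tests run without the feature being implemented:

```python
import numpy as np
import pandas as pd


def preserve(df_out, original):
    """Hypothetical stand-in for the proposed dtype-preservation step."""
    subset = original[list(df_out.columns)]
    if not np.array_equal(df_out.to_numpy(), subset.to_numpy()):
        return df_out  # values changed: skip preservation
    for i in range(df_out.shape[1]):
        df_out.isetitem(i, df_out.iloc[:, i].astype(subset.dtypes.iloc[i]))
    return df_out


def test_preserves_dtypes_for_subset():
    X = pd.DataFrame({"a": np.array([1, 2], dtype=np.float16),
                      "b": [0.5, 1.5]})
    out = preserve(pd.DataFrame(X[["a"]].to_numpy(np.float64),
                                columns=["a"]), X)
    assert out["a"].dtype == np.float16


def test_skips_when_values_change():
    X = pd.DataFrame({"a": np.array([1, 2], dtype=np.float16)})
    scaled = pd.DataFrame(X.to_numpy(np.float64) * 2.0, columns=["a"])
    out = preserve(scaled, X)
    assert out["a"].dtype == np.float64


test_preserves_dtypes_for_subset()
test_skips_when_values_change()
```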
This approach provides the requested ability to preserve statistically relevant dtype information in the pandas output for applicable transformers, while ensuring safety and avoiding incorrect casts when transformations alter values.