|
9 | 9 | import pytest
|
10 | 10 | import sklearn.datasets
|
11 | 11 | import torch
|
| 12 | +from sklearn import config_context |
12 | 13 | from sklearn.base import check_is_fitted
|
13 | 14 | from sklearn.pipeline import Pipeline
|
14 | 15 | from sklearn.preprocessing import StandardScaler
|
@@ -310,3 +311,37 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) ->
|
310 | 311 | assert embeddings.shape[0] == n_estimators
|
311 | 312 | assert embeddings.shape[1] == X.shape[0]
|
312 | 313 | assert embeddings.shape[2] == encoder_shape
|
| 314 | + |
| 315 | + |
def test_pandas_output_config():
    """Check that predictions do not depend on sklearn's ``transform_output``.

    The classifier is fitted once under sklearn's default output
    configuration to obtain baseline predictions. It is then re-fitted and
    queried inside ``config_context(transform_output=...)`` for each
    supported dataframe container, and the results are compared against the
    baseline.

    ``pandas`` and ``polars`` are optional dependencies. If one of them is
    not importable the test is skipped from that point on instead of
    erroring, so the containers are checked in order of how commonly they
    are installed.
    """
    # Small synthetic classification problem; seeded so runs are repeatable.
    X, y = sklearn.datasets.make_classification(
        n_samples=100,
        n_features=10,
        random_state=19,
    )

    model = TabPFNClassifier(n_estimators=1, random_state=42)

    # Baseline: predictions under the default output configuration.
    model.fit(X, y)
    default_pred = model.predict(X)
    default_proba = model.predict_proba(X)

    # NOTE(review): "polars" is only accepted by newer scikit-learn releases
    # (believed to be >= 1.4) -- confirm against the project's minimum
    # supported version.
    for output_format in ("pandas", "polars"):
        # Skip rather than error when the optional container library is
        # missing from the environment.
        pytest.importorskip(output_format)

        # Fit inside the context as well, so both the fit-time and the
        # predict-time code paths run with the alternative output config.
        with config_context(transform_output=output_format):
            model.fit(X, y)
            pred = model.predict(X)
            proba = model.predict_proba(X)

        np.testing.assert_array_equal(
            default_pred,
            pred,
            err_msg=f"predict() differs with transform_output={output_format!r}",
        )
        np.testing.assert_array_almost_equal(
            default_proba,
            proba,
            err_msg=(
                f"predict_proba() differs with transform_output={output_format!r}"
            ),
        )
0 commit comments