Skip to content

Commit 2178ec9

Browse files
committed
add failing tests
1 parent 05ab7da commit 2178ec9

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

tests/test_classifier_interface.py

+35
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
import sklearn.datasets
1111
import torch
12+
from sklearn import config_context
1213
from sklearn.base import check_is_fitted
1314
from sklearn.pipeline import Pipeline
1415
from sklearn.preprocessing import StandardScaler
@@ -310,3 +311,37 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) ->
310311
assert embeddings.shape[0] == n_estimators
311312
assert embeddings.shape[1] == X.shape[0]
312313
assert embeddings.shape[2] == encoder_shape
314+
315+
316+
def test_pandas_output_config():
317+
"""Test compatibility with sklearn's output configuration settings."""
318+
# Generate synthetic classification data
319+
X, y = sklearn.datasets.make_classification(
320+
n_samples=100,
321+
n_features=10,
322+
random_state=19,
323+
)
324+
325+
# Initialize TabPFN
326+
model = TabPFNClassifier(n_estimators=1, random_state=42)
327+
328+
# Get default predictions
329+
model.fit(X, y)
330+
default_pred = model.predict(X)
331+
default_proba = model.predict_proba(X)
332+
333+
# Test with pandas output
334+
with config_context(transform_output="pandas"):
335+
model.fit(X, y)
336+
pandas_pred = model.predict(X)
337+
pandas_proba = model.predict_proba(X)
338+
np.testing.assert_array_equal(default_pred, pandas_pred)
339+
np.testing.assert_array_almost_equal(default_proba, pandas_proba)
340+
341+
# Test with polars output
342+
with config_context(transform_output="polars"):
343+
model.fit(X, y)
344+
polars_pred = model.predict(X)
345+
polars_proba = model.predict_proba(X)
346+
np.testing.assert_array_equal(default_pred, polars_pred)
347+
np.testing.assert_array_almost_equal(default_proba, polars_proba)

tests/test_regressor_interface.py

+30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
import sklearn.datasets
1111
import torch
12+
from sklearn import config_context
1213
from sklearn.base import check_is_fitted
1314
from sklearn.pipeline import Pipeline
1415
from sklearn.preprocessing import StandardScaler
@@ -321,3 +322,32 @@ def test_overflow():
321322

322323
predictions = regressor.predict(X)
323324
assert predictions.shape == (X.shape[0],), "Predictions shape is incorrect"
325+
326+
327+
def test_pandas_output_config():
328+
"""Test compatibility with sklearn's output configuration settings."""
329+
# Generate synthetic regression data
330+
X, y = sklearn.datasets.make_regression(
331+
n_samples=100,
332+
n_features=10,
333+
random_state=19,
334+
)
335+
336+
# Initialize TabPFN
337+
model = TabPFNRegressor(n_estimators=1, random_state=42)
338+
339+
# Get default predictions
340+
model.fit(X, y)
341+
default_pred = model.predict(X)
342+
343+
# Test with pandas output
344+
with config_context(transform_output="pandas"):
345+
model.fit(X, y)
346+
pandas_pred = model.predict(X)
347+
np.testing.assert_array_almost_equal(default_pred, pandas_pred)
348+
349+
# Test with polars output
350+
with config_context(transform_output="polars"):
351+
model.fit(X, y)
352+
polars_pred = model.predict(X)
353+
np.testing.assert_array_almost_equal(default_pred, polars_pred)

0 commit comments

Comments
 (0)