Skip to content

Commit dd72ac3

Browse files
committed
override set_output in the classifier and the regressor
1 parent 2178ec9 commit dd72ac3

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/tabpfn/classifier.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy as np
2727
import torch
28+
from sklearn import config_context
2829
from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
2930
from sklearn.preprocessing import LabelEncoder
3031

@@ -374,6 +375,7 @@ def __sklearn_tags__(self) -> Tags:
374375
tags.estimator_type = "classifier"
375376
return tags
376377

378+
@config_context(transform_output="default")
377379
def fit(self, X: XType, y: YType) -> Self:
378380
"""Fit the model.
379381
@@ -518,6 +520,7 @@ def predict(self, X: XType) -> np.ndarray:
518520
y = np.argmax(proba, axis=1)
519521
return self.label_encoder_.inverse_transform(y) # type: ignore
520522

523+
@config_context(transform_output="default")
521524
def predict_proba(self, X: XType) -> np.ndarray:
522525
"""Predict the probabilities of the classes for the provided input samples.
523526

src/tabpfn/regressor.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy as np
2727
import torch
28+
from sklearn import config_context
2829
from sklearn.base import (
2930
BaseEstimator,
3031
RegressorMixin,
@@ -380,6 +381,7 @@ def __sklearn_tags__(self) -> Tags:
380381
tags.estimator_type = "regressor"
381382
return tags
382383

384+
@config_context(transform_output="default")
383385
def fit(self, X: XType, y: YType) -> Self:
384386
"""Fit the model.
385387
@@ -556,6 +558,7 @@ def predict(
556558
) -> dict[str, np.ndarray | FullSupportBarDistribution]: ...
557559

558560
# FIXME: improve to not have noqa C901, PLR0912
561+
@config_context(transform_output="default")
559562
def predict( # noqa: C901, PLR0912
560563
self,
561564
X: XType,

0 commit comments

Comments
 (0)