Skip to content

Commit

Permalink
Store last train set uid in classifier/regressor class (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidotte committed Aug 19, 2024
1 parent 7af0659 commit 1e6f73f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
16 changes: 12 additions & 4 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
self.remove_outliers = remove_outliers
self.add_fingerprint_features = add_fingerprint_features
self.subsample_samples = subsample_samples
self.last_train_set_uid = None

def _validate_targets_and_classes(self, y) -> np.ndarray:
from sklearn.utils import column_or_1d
Expand All @@ -206,7 +207,7 @@ def fit(self, X, y):
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)
config.g_tabpfn_config.inference_handler.fit(X, y)
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
raise NotImplementedError(
Expand All @@ -223,7 +224,10 @@ def predict(self, X):
def predict_proba(self, X):
check_is_fitted(self)
return config.g_tabpfn_config.inference_handler.predict(
X, task="classification", config=self.get_params()
X,
task="classification",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
)["probas"]


Expand Down Expand Up @@ -323,6 +327,7 @@ def __init__(
self.cancel_nan_borders = cancel_nan_borders
self.super_bar_dist_averaging = super_bar_dist_averaging
self.subsample_samples = subsample_samples
self.last_train_set_uid = None

def fit(self, X, y):
# assert init() is called
Expand All @@ -335,7 +340,7 @@ def fit(self, X, y):
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)
config.g_tabpfn_config.inference_handler.fit(X, y)
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
raise NotImplementedError(
Expand All @@ -357,5 +362,8 @@ def predict(self, X):
def predict_full(self, X):
check_is_fitted(self)
return config.g_tabpfn_config.inference_handler.predict(
X, task="regression", config=self.get_params()
X,
task="regression",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
)
15 changes: 10 additions & 5 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,25 @@ class InferenceClient(ServiceClientWrapper):

def __init__(self, service_client=ServiceClient()):
super().__init__(service_client)
self.last_train_set_uid = None

def fit(self, X, y) -> None:
def fit(self, X, y) -> str:
if not self.service_client.is_initialized:
raise RuntimeError(
"Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account."
"Please Note: The email verification token expires in 30 minutes."
)

self.last_train_set_uid = self.service_client.upload_train_set(X, y)
return self.service_client.upload_train_set(X, y)

def predict(self, X, task: Literal["classification", "regression"], config=None):
def predict(
self,
X,
task: Literal["classification", "regression"],
train_set_uid: str,
config=None,
):
return self.service_client.predict(
train_set_uid=self.last_train_set_uid,
train_set_uid=train_set_uid,
x_test=X,
tabpfn_config=config,
task=task,
Expand Down
2 changes: 1 addition & 1 deletion tabpfn_client/tabpfn_common_utils

0 comments on commit 1e6f73f

Please sign in to comment.