From 1e6f73f707c7dcd11cbe9f24d734e930d0e492ac Mon Sep 17 00:00:00 2001 From: David Otte <63399116+davidotte@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:03:55 +0200 Subject: [PATCH] Store last train set uid in classifier/regressor class (#33) --- tabpfn_client/estimator.py | 16 ++++++++++++---- tabpfn_client/service_wrapper.py | 15 ++++++++++----- tabpfn_client/tabpfn_common_utils | 2 +- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 65f061f..7829c39 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -181,6 +181,7 @@ def __init__( self.remove_outliers = remove_outliers self.add_fingerprint_features = add_fingerprint_features self.subsample_samples = subsample_samples + self.last_train_set_uid = None def _validate_targets_and_classes(self, y) -> np.ndarray: from sklearn.utils import column_or_1d @@ -206,7 +207,7 @@ def fit(self, X, y): ), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)" except AssertionError as e: print(e) - config.g_tabpfn_config.inference_handler.fit(X, y) + self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: raise NotImplementedError( @@ -223,7 +224,10 @@ def predict(self, X): def predict_proba(self, X): check_is_fitted(self) return config.g_tabpfn_config.inference_handler.predict( - X, task="classification", config=self.get_params() + X, + task="classification", + train_set_uid=self.last_train_set_uid, + config=self.get_params(), )["probas"] @@ -323,6 +327,7 @@ def __init__( self.cancel_nan_borders = cancel_nan_borders self.super_bar_dist_averaging = super_bar_dist_averaging self.subsample_samples = subsample_samples + self.last_train_set_uid = None def fit(self, X, y): # assert init() is called @@ -335,7 +340,7 @@ def fit(self, X, y): ), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)" except AssertionError as e: print(e) - config.g_tabpfn_config.inference_handler.fit(X, y) + self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: raise NotImplementedError( @@ -357,5 +362,8 @@ def predict(self, X): def predict_full(self, X): check_is_fitted(self) return config.g_tabpfn_config.inference_handler.predict( - X, task="regression", config=self.get_params() + X, + task="regression", + train_set_uid=self.last_train_set_uid, + config=self.get_params(), ) diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index cc3026e..f4954e0 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -178,20 +178,25 @@ class InferenceClient(ServiceClientWrapper): def __init__(self, service_client=ServiceClient()): super().__init__(service_client) - self.last_train_set_uid = None - def fit(self, X, y) -> None: + def fit(self, X, y) -> str: if not self.service_client.is_initialized: raise RuntimeError( "Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account." "Please Note: The email verification token expires in 30 minutes." ) - self.last_train_set_uid = self.service_client.upload_train_set(X, y) + return self.service_client.upload_train_set(X, y) - def predict(self, X, task: Literal["classification", "regression"], config=None): + def predict( + self, + X, + task: Literal["classification", "regression"], + train_set_uid: str, + config=None, + ): return self.service_client.predict( - train_set_uid=self.last_train_set_uid, + train_set_uid=train_set_uid, x_test=X, tabpfn_config=config, task=task, diff --git a/tabpfn_client/tabpfn_common_utils b/tabpfn_client/tabpfn_common_utils index c26d6d9..cb44694 160000 --- a/tabpfn_client/tabpfn_common_utils +++ b/tabpfn_client/tabpfn_common_utils @@ -1 +1 @@ -Subproject commit c26d6d928fdd7600f20a5700b25c75edad573c61 +Subproject commit cb4469425eba995b4cefad1357c020878e1a6d02