diff --git a/pyproject.toml b/pyproject.toml index 7d0ce3b..2394dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tabpfn-client" -version = "0.0.20" +version = "0.0.21" requires-python = ">=3.10" dependencies = [ "httpx>=0.24.1", diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 21d2a9d..81c1c09 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -313,7 +313,7 @@ def register( response = self.httpx_client.post( self.server_endpoints.register.path, - params={ + json={ "email": email, "password": password, "password_confirm": password_confirm, diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 65f061f..7829c39 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -181,6 +181,7 @@ def __init__( self.remove_outliers = remove_outliers self.add_fingerprint_features = add_fingerprint_features self.subsample_samples = subsample_samples + self.last_train_set_uid = None def _validate_targets_and_classes(self, y) -> np.ndarray: from sklearn.utils import column_or_1d @@ -206,7 +207,7 @@ def fit(self, X, y): ), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)" except AssertionError as e: print(e) - config.g_tabpfn_config.inference_handler.fit(X, y) + self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: raise NotImplementedError( @@ -223,7 +224,10 @@ def predict(self, X): def predict_proba(self, X): check_is_fitted(self) return config.g_tabpfn_config.inference_handler.predict( - X, task="classification", config=self.get_params() + X, + task="classification", + train_set_uid=self.last_train_set_uid, + config=self.get_params(), )["probas"] @@ -323,6 +327,7 @@ def __init__( self.cancel_nan_borders = cancel_nan_borders self.super_bar_dist_averaging = super_bar_dist_averaging self.subsample_samples = subsample_samples + self.last_train_set_uid = None def fit(self, X, y): # assert init() is called @@ -335,7 +340,7 @@ def fit(self, X, y): ), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)" except AssertionError as e: print(e) - config.g_tabpfn_config.inference_handler.fit(X, y) + self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: raise NotImplementedError( @@ -357,5 +362,8 @@ def predict(self, X): def predict_full(self, X): check_is_fitted(self) return config.g_tabpfn_config.inference_handler.predict( - X, task="regression", config=self.get_params() + X, + task="regression", + train_set_uid=self.last_train_set_uid, + config=self.get_params(), ) diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index cc3026e..f4954e0 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -178,20 +178,25 @@ class InferenceClient(ServiceClientWrapper): def __init__(self, service_client=ServiceClient()): super().__init__(service_client) - self.last_train_set_uid = None - def fit(self, X, y) -> None: + def fit(self, X, y) -> str: if not self.service_client.is_initialized: raise RuntimeError( "Dear TabPFN User, please initialize the client first by verifying your E-mail address sent to your registered E-mail account." "Please Note: The email verification token expires in 30 minutes." ) - self.last_train_set_uid = self.service_client.upload_train_set(X, y) + return self.service_client.upload_train_set(X, y) - def predict(self, X, task: Literal["classification", "regression"], config=None): + def predict( + self, + X, + task: Literal["classification", "regression"], + train_set_uid: str, + config=None, + ): return self.service_client.predict( - train_set_uid=self.last_train_set_uid, + train_set_uid=train_set_uid, x_test=X, tabpfn_config=config, task=task, diff --git a/tabpfn_client/tabpfn_common_utils b/tabpfn_client/tabpfn_common_utils index c26d6d9..cb44694 160000 --- a/tabpfn_client/tabpfn_common_utils +++ b/tabpfn_client/tabpfn_common_utils @@ -1 +1 @@ -Subproject commit c26d6d928fdd7600f20a5700b25c75edad573c61 +Subproject commit cb4469425eba995b4cefad1357c020878e1a6d02