From c6aa6661153305f86ba009541ea997a4b77731e1 Mon Sep 17 00:00:00 2001 From: "Liam, Shi Bin Hoo" <44376667+liam-sbhoo@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:56:46 +0200 Subject: [PATCH] Remove outdated assertion on model support (#36) * Remove outdated assertion on model support * Add support for different models, fix unnecessary welcome message. * Fix formatting * Fix test case * Rename model config * Fix arguments, add test cases * Ruff --- tabpfn_client/config.py | 11 +++-- tabpfn_client/estimator.py | 42 +++++++++++++++---- .../tests/unit/test_tabpfn_classifier.py | 4 -- .../tests/unit/test_tabpfn_regressor.py | 15 +++++-- 4 files changed, 53 insertions(+), 19 deletions(-) diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index cba80ed..259432e 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -7,8 +7,8 @@ class TabPFNConfig: - is_initialized = None - use_server = None + is_initialized = False + use_server = False user_auth_handler = None inference_handler = None @@ -21,6 +21,10 @@ def init(use_server=True): use_server = use_server global g_tabpfn_config + if g_tabpfn_config.is_initialized: + # Only do the following if the initialization has not been done yet + return + if use_server: service_client = ServiceClient() user_auth_handler = UserAuthenticationClient(service_client) @@ -60,7 +64,8 @@ def init(use_server=True): g_tabpfn_config.inference_handler = InferenceClient(service_client) else: - g_tabpfn_config.use_server = False + raise RuntimeError("Local inference is not supported yet.") + # g_tabpfn_config.use_server = False g_tabpfn_config.is_initialized = True diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 7829c39..290600c 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -232,9 +232,17 @@ def predict_proba(self, X): class TabPFNRegressor(BaseEstimator, RegressorMixin): + _AVAILABLE_MODELS = [ + "default", + "2noar4o2", + "5wof9ojf", + "09gpqh39", 
"wyl4o83o", + ] + def __init__( self, - model: str = "latest_tabpfn_hosted", + model: str = "default", n_estimators: int = 8, preprocess_transforms: Tuple[PreprocessorConfig, ...] = ( PreprocessorConfig( @@ -310,6 +318,9 @@ def __init__( If in 0 to 1, the value is viewed as a fraction of the training set size. """ + if model not in self._AVAILABLE_MODELS: + raise ValueError(f"Invalid model name: {model}") + self.model = model self.n_estimators = n_estimators self.preprocess_transforms = preprocess_transforms @@ -334,12 +345,6 @@ def fit(self, X, y): init() if config.g_tabpfn_config.use_server: - try: - assert ( - self.model == "latest_tabpfn_hosted" - ), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)" - except AssertionError as e: - print(e) self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y) self.fitted_ = True else: @@ -361,9 +366,30 @@ def predict(self, X): def predict_full(self, X): check_is_fitted(self) + + estimator_param = self.get_params() + if "model" in estimator_param: + # replace model by model_path since in TabPFN defines model as model_path + estimator_param["model_path"] = self._model_name_to_path( + estimator_param.pop("model") + ) + return config.g_tabpfn_config.inference_handler.predict( X, task="regression", train_set_uid=self.last_train_set_uid, - config=self.get_params(), + config=estimator_param, ) + + @classmethod + def list_available_models(cls) -> list[str]: + return cls._AVAILABLE_MODELS + + def _model_name_to_path(self, model_name: str) -> str: + base_path = "/home/venv/lib/python3.9/site-packages/tabpfn/model_cache/model_hans_regression" + if model_name == "default": + return f"{base_path}.ckpt" + elif model_name in self._AVAILABLE_MODELS: + return f"{base_path}_{model_name}.ckpt" + else: + raise ValueError(f"Invalid model name: {model_name}") diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 
f6cd64f..863d14b 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -58,10 +58,6 @@ def test_init_remote_classifier( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) - mock_server.router.get(mock_server.endpoints.protected_root.path).respond( - 200, json={"message": "Welcome to the protected zone, user!"} - ) - mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]] predict_route = mock_server.router.post(mock_server.endpoints.predict.path) predict_route.respond(200, json={"classification": mock_predict_response}) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index ff5c594..2b5215e 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -56,10 +56,6 @@ def test_init_remote_regressor( mock_server.endpoints.retrieve_greeting_messages.path ).respond(200, json={"messages": []}) - mock_server.router.get(mock_server.endpoints.protected_root.path).respond( - 200, json={"message": "Welcome to the protected zone, user!"} - ) - mock_predict_response = { "mean": [100, 200, 300], "median": [110, 210, 310], @@ -87,6 +83,17 @@ def test_init_remote_regressor( "check that n_estimators is passed to the server", ) + def test_valid_model_config(self): + # Test with valid model configuration + model_name = TabPFNRegressor.list_available_models()[0] + valid_config = TabPFNRegressor(model=model_name) + self.assertEqual(valid_config.model, model_name) + + def test_invalid_model_config(self): + # Test with invalid model configuration + with self.assertRaises(ValueError): + TabPFNRegressor(model="invalid_model_name") + @with_mock_server() def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication