Skip to content

Commit

Permalink
Remove outdated assertion on model support (#36)
Browse files Browse the repository at this point in the history
* Remove outdated assertion on model support

* Add supports for different models, fix unnecessary welcome message.

* Fix formatting

* Fix test case

* Rename model config

* Fix arguments, add test cases

* Ruff
  • Loading branch information
liam-sbhoo authored Sep 20, 2024
1 parent 1f33d89 commit c6aa666
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 19 deletions.
11 changes: 8 additions & 3 deletions tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class TabPFNConfig:
is_initialized = None
use_server = None
is_initialized = False
use_server = False
user_auth_handler = None
inference_handler = None

Expand All @@ -21,6 +21,10 @@ def init(use_server=True):
use_server = use_server
global g_tabpfn_config

if g_tabpfn_config.is_initialized:
# Only do the following if the initialization has not been done yet
return

if use_server:
service_client = ServiceClient()
user_auth_handler = UserAuthenticationClient(service_client)
Expand Down Expand Up @@ -60,7 +64,8 @@ def init(use_server=True):
g_tabpfn_config.inference_handler = InferenceClient(service_client)

else:
g_tabpfn_config.use_server = False
raise RuntimeError("Local inference is not supported yet.")
# g_tabpfn_config.use_server = False

g_tabpfn_config.is_initialized = True

Expand Down
42 changes: 34 additions & 8 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,17 @@ def predict_proba(self, X):


class TabPFNRegressor(BaseEstimator, RegressorMixin):
_AVAILABLE_MODELS = [
"default",
"2noar4o2",
"5wof9ojf",
"09gpqh39",
"wyl4o83o",
]

def __init__(
self,
model: str = "latest_tabpfn_hosted",
model: str = "default",
n_estimators: int = 8,
preprocess_transforms: Tuple[PreprocessorConfig, ...] = (
PreprocessorConfig(
Expand Down Expand Up @@ -310,6 +318,9 @@ def __init__(
If in 0 to 1, the value is viewed as a fraction of the training set size.
"""

if model not in self._AVAILABLE_MODELS:
raise ValueError(f"Invalid model name: {model}")

self.model = model
self.n_estimators = n_estimators
self.preprocess_transforms = preprocess_transforms
Expand All @@ -334,12 +345,6 @@ def fit(self, X, y):
init()

if config.g_tabpfn_config.use_server:
try:
assert (
self.model == "latest_tabpfn_hosted"
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
Expand All @@ -361,9 +366,30 @@ def predict(self, X):

def predict_full(self, X):
check_is_fitted(self)

estimator_param = self.get_params()
if "model" in estimator_param:
# replace model by model_path since in TabPFN defines model as model_path
estimator_param["model_path"] = self._model_name_to_path(
estimator_param.pop("model")
)

return config.g_tabpfn_config.inference_handler.predict(
X,
task="regression",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
config=estimator_param,
)

@classmethod
def list_available_models(cls) -> list[str]:
return cls._AVAILABLE_MODELS

def _model_name_to_path(self, model_name: str) -> str:
base_path = "/home/venv/lib/python3.9/site-packages/tabpfn/model_cache/model_hans_regression"
if model_name == "default":
return f"{base_path}.ckpt"
elif model_name in self._AVAILABLE_MODELS:
return f"{base_path}_{model_name}.ckpt"
else:
raise ValueError(f"Invalid model name: {model_name}")
4 changes: 0 additions & 4 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ def test_init_remote_classifier(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]]
predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
predict_route.respond(200, json={"classification": mock_predict_response})
Expand Down
15 changes: 11 additions & 4 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ def test_init_remote_regressor(
mock_server.endpoints.retrieve_greeting_messages.path
).respond(200, json={"messages": []})

mock_server.router.get(mock_server.endpoints.protected_root.path).respond(
200, json={"message": "Welcome to the protected zone, user!"}
)

mock_predict_response = {
"mean": [100, 200, 300],
"median": [110, 210, 310],
Expand Down Expand Up @@ -87,6 +83,17 @@ def test_init_remote_regressor(
"check that n_estimators is passed to the server",
)

def test_valid_model_config(self):
# Test with valid model configuration
model_name = TabPFNRegressor.list_available_models()[0]
valid_config = TabPFNRegressor(model=model_name)
self.assertEqual(valid_config.model, model_name)

def test_invalid_model_config(self):
# Test with invalid model configuration
with self.assertRaises(ValueError):
TabPFNRegressor(model="invalid_model_name")

@with_mock_server()
def test_reuse_saved_access_token(self, mock_server):
# mock connection and authentication
Expand Down

0 comments on commit c6aa666

Please sign in to comment.