Skip to content

Commit

Permalink
Merge branch 'main' into add_data_checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Sathya98 authored Aug 30, 2024
2 parents c93a709 + 575ea39 commit e0f9759
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 60 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tabpfn-client"
version = "0.0.15"
version = "0.0.21"
requires-python = ">=3.10"
dependencies = [
"httpx>=0.24.1",
Expand Down
3 changes: 1 addition & 2 deletions quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.datasets import load_breast_cancer, load_diabetes
from sklearn.model_selection import train_test_split

from tabpfn_client import UserDataClient, init
from tabpfn_client import UserDataClient
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor

logging.basicConfig(level=logging.DEBUG)
Expand All @@ -21,7 +21,6 @@
X, y, test_size=0.33, random_state=42
)

init()
tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
Expand Down
77 changes: 56 additions & 21 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def predict(
return result

@staticmethod
def _validate_response(response, method_name, only_version_check=False):
def _validate_response(
response: httpx.Response, method_name, only_version_check=False
):
# If status code is 200, no errors occurred on the server side.
if response.status_code == 200:
return
Expand All @@ -207,11 +209,11 @@ def _validate_response(response, method_name, only_version_check=False):
try:
load = response.json()
except json.JSONDecodeError as e:
logging.error(f"Failed to parse JSON from response in {method_name}: {e}")
logging.info(f"Failed to parse JSON from response in {method_name}: {e}")

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(
logger.info(
f"Fail to call {method_name}, response status: {response.status_code}"
)
raise RuntimeError(load.get("detail"))
Expand All @@ -223,20 +225,34 @@ def _validate_response(response, method_name, only_version_check=False):
logger.error(
f"Fail to call {method_name}, response status: {response.status_code}"
)
if (
len(
reponse_split_up := response.text.split(
"The following exception has occurred:"
try:
if (
len(
reponse_split_up := response.text.split(
"The following exception has occurred:"
)
)
)
> 1
):
raise RuntimeError(
f"Fail to call {method_name} with error: {reponse_split_up[1]}"
)
> 1
):
relevant_reponse_text = reponse_split_up[1].split(
"debug_error_string"
)[0]
if "ValueError" in relevant_reponse_text:
# Extract the ValueError message
value_error_msg = relevant_reponse_text.split(
"ValueError. Arguments: ("
)[1].split(",)")[0]
# Remove extra quotes and spaces
value_error_msg = value_error_msg.strip("'")
# Raise the ValueError with the extracted message
raise ValueError(value_error_msg)
raise RuntimeError(relevant_reponse_text)
except Exception as e:
if isinstance(e, (ValueError, RuntimeError)):
raise e
raise RuntimeError(
f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}"
f"Fail to call {method_name} with error: {response.status_code}, reason: "
f"{response.reason_phrase} and text: {response.text}"
)

def try_connection(self) -> bool:
Expand All @@ -257,7 +273,7 @@ def try_connection(self) -> bool:

return found_valid_connection

def try_authenticate(self, access_token) -> bool:
def is_auth_token_outdated(self, access_token) -> bool | None:
"""
Check if the provided access token is valid and return True if successful.
"""
Expand All @@ -267,11 +283,13 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self._validate_response(response, "try_authenticate", only_version_check=True)

self._validate_response(
response, "is_auth_token_outdated", only_version_check=True
)
if response.status_code == 200:
is_authenticated = True

elif response.status_code == 403:
is_authenticated = None
return is_authenticated

def validate_email(self, email: str) -> tuple[bool, str]:
Expand Down Expand Up @@ -332,7 +350,7 @@ def register(

response = self.httpx_client.post(
self.server_endpoints.register.path,
params={
json={
"email": email,
"password": password,
"password_confirm": password_confirm,
Expand All @@ -349,7 +367,8 @@ def register(
is_created = False
message = response.json()["detail"]

return is_created, message
access_token = response.json()["token"] if is_created else None
return is_created, message, access_token

def login(self, email: str, password: str) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -418,6 +437,22 @@ def send_reset_password_email(self, email: str) -> tuple[bool, str]:
message = response.json()["detail"]
return sent, message

def send_verification_email(self, access_token: str) -> tuple[bool, str]:
    """
    Ask the server to send a verification email for the authenticated user.

    Issues a POST to the configured ``send_verification_email`` endpoint,
    passing ``access_token`` as a bearer token.

    Parameters
    ----------
    access_token : str
        Token placed in the ``Authorization: Bearer ...`` header.

    Returns
    -------
    tuple[bool, str]
        ``(sent, message)``. ``sent`` is True only for HTTP status 200.
        ``message`` is the ``"message"`` field of the JSON body on success and
        the ``"detail"`` field otherwise. If the body is not JSON or lacks
        that field, the raw response text (or the status code when the body
        is empty) is returned instead of raising.
    """
    response = self.httpx_client.post(
        self.server_endpoints.send_verification_email.path,
        headers={"Authorization": f"Bearer {access_token}"},
    )

    sent = response.status_code == 200
    field = "message" if sent else "detail"

    try:
        message = response.json()[field]
    except (ValueError, KeyError, TypeError):
        # ValueError: body is not valid JSON (json.JSONDecodeError subclasses it).
        # KeyError / TypeError: JSON parsed but has no such field or is not a
        # mapping. Callers print the message, so degrade to plain text rather
        # than propagate a parsing error.
        message = response.text or f"HTTP {response.status_code}"

    return sent, message

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
Expand Down
7 changes: 6 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def init(use_server=True):

is_valid_token_set = user_auth_handler.try_reuse_existing_token()

if is_valid_token_set:
if isinstance(is_valid_token_set, bool) and is_valid_token_set:
PromptAgent.prompt_reusing_existing_token()
elif (
isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None
):
print("Your email is not verified. Please verify your email to continue...")
PromptAgent.reverify_email(is_valid_token_set[1], user_auth_handler)
else:
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
Expand Down
44 changes: 31 additions & 13 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, asdict

import numpy as np
from tabpfn_client import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -180,13 +181,24 @@ def __init__(
self.remove_outliers = remove_outliers
self.add_fingerprint_features = add_fingerprint_features
self.subsample_samples = subsample_samples
self.last_train_set_uid = None

def _validate_targets_and_classes(self, y) -> np.ndarray:
    """
    Validate classification targets and record the distinct class labels.

    The targets are flattened to 1-D with ``column_or_1d``, checked with
    ``check_classification_targets``, and the sorted unique non-NaN labels
    are stored on ``self.classes_``.

    Parameters
    ----------
    y : array-like
        Target labels, either 1-D or a single column.

    Returns
    -------
    np.ndarray
        The flattened 1-D target array.
    """
    from sklearn.utils import column_or_1d
    from sklearn.utils.multiclass import check_classification_targets

    y_ = column_or_1d(y, warn=True)
    check_classification_targets(y_)

    # Build the NaN mask from the flattened array so its shape always matches
    # y_ (the raw input may be a column vector or a plain list).
    if y_.dtype.kind in "fc":
        not_nan_mask = ~np.isnan(y_)
    elif y_.dtype.kind == "O":
        # np.isnan raises TypeError on object arrays (e.g. string labels);
        # NaN is the only value that compares unequal to itself.
        not_nan_mask = np.fromiter((v == v for v in y_), dtype=bool, count=y_.size)
    else:
        # Integer, boolean and string dtypes cannot hold NaN.
        not_nan_mask = np.ones(y_.shape, dtype=bool)

    # Get classes before any type conversion to guarantee correct class labels.
    self.classes_ = np.unique(y_[not_nan_mask])
    return y_

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNClassifier"
)
init()

self._validate_targets_and_classes(y)

if config.g_tabpfn_config.use_server:
try:
Expand All @@ -195,7 +207,7 @@ def fit(self, X, y):
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)
config.g_tabpfn_config.inference_handler.fit(X, y)
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
raise NotImplementedError(
Expand All @@ -205,12 +217,17 @@ def fit(self, X, y):

def predict(self, X):
probas = self.predict_proba(X)
return np.argmax(probas, axis=1)
y = np.argmax(probas, axis=1)
y = self.classes_.take(np.asarray(y, dtype=int))
return y

def predict_proba(self, X):
check_is_fitted(self)
return config.g_tabpfn_config.inference_handler.predict(
X, task="classification", config=self.get_params()
X,
task="classification",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
)["probas"]


Expand Down Expand Up @@ -310,13 +327,11 @@ def __init__(
self.cancel_nan_borders = cancel_nan_borders
self.super_bar_dist_averaging = super_bar_dist_averaging
self.subsample_samples = subsample_samples
self.last_train_set_uid = None

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNRegressor"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand All @@ -325,7 +340,7 @@ def fit(self, X, y):
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)
config.g_tabpfn_config.inference_handler.fit(X, y)
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
raise NotImplementedError(
Expand All @@ -347,5 +362,8 @@ def predict(self, X):
def predict_full(self, X):
check_is_fitted(self)
return config.g_tabpfn_config.inference_handler.predict(
X, task="regression", config=self.get_params()
X,
task="regression",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
)
32 changes: 32 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
]
)
choice = cls._choice_with_retries(prompt, ["1", "2"])
email = ""

# Registration
if choice == "1":
Expand Down Expand Up @@ -207,6 +208,37 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def reverify_email(
    cls, access_token, user_auth_handler: "UserAuthenticationClient"
):
    """
    Remind the user to verify their email and offer to resend the link.

    Prints a reminder, then asks a y/n question through
    ``cls._choice_with_retries``. On "y" it calls
    ``user_auth_handler.send_verification_email(access_token)`` and prints
    whether the request succeeded. Returns None in every case.
    """
    reminder_lines = [
        "Please check your inbox for the verification email.",
        "Note: The email might be in your spam folder or could have expired.",
    ]
    print(cls.indent("\n".join(reminder_lines)))

    resend_question = "Do you want to resend email verification link? (y/n): "
    answer = cls._choice_with_retries(resend_question, ["y", "n"])

    # Nothing more to do unless the user asked for a new link.
    if answer != "y":
        return

    sent, message = user_auth_handler.send_verification_email(access_token)
    if sent:
        confirmation = cls.indent(
            "A verification email has been sent, provided the details are correct!"
        )
        print(confirmation + "\n")
    else:
        print(cls.indent("Failed to send verification email: " + message))
    return

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
Expand Down
10 changes: 5 additions & 5 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ endpoints:
methods: [ "POST" ]
description: "User login"

send_verification_email:
path: "/auth/send_verification_email/"
methods: [ "POST" ]
description: "Send verification email or for reverification"

send_reset_password_email:
path: "/auth/send_reset_password_email/"
methods: [ "POST" ]
Expand All @@ -44,11 +49,6 @@ endpoints:
methods: [ "GET" ]
description: "Retrieve new greeting messages"

add_user_information:
path: "/add_user_information/"
methods: [ "POST" ]
description: "Add additional user information to database"

protected_root:
path: "/protected/"
methods: [ "GET" ]
Expand Down
Loading

0 comments on commit e0f9759

Please sign in to comment.