Skip to content

Commit

Permalink
fix class remappings (#31)
Browse files Browse the repository at this point in the history
* fix class remappings

* make parsing even nicer

* fix tests
  • Loading branch information
SamuelGabriel committed Jul 16, 2024
1 parent b14cb11 commit 7af0659
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tabpfn-client"
version = "0.0.15"
version = "0.0.20"
requires-python = ">=3.10"
dependencies = [
"httpx>=0.24.1",
Expand Down
46 changes: 31 additions & 15 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def predict(
return result

@staticmethod
def _validate_response(response, method_name, only_version_check=False):
def _validate_response(
response: httpx.Response, method_name, only_version_check=False
):
# If status code is 200, no errors occurred on the server side.
if response.status_code == 200:
return
Expand All @@ -170,11 +172,11 @@ def _validate_response(response, method_name, only_version_check=False):
try:
load = response.json()
except json.JSONDecodeError as e:
logging.error(f"Failed to parse JSON from response in {method_name}: {e}")
logging.info(f"Failed to parse JSON from response in {method_name}: {e}")

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(
logger.info(
f"Fail to call {method_name}, response status: {response.status_code}"
)
raise RuntimeError(load.get("detail"))
Expand All @@ -186,20 +188,34 @@ def _validate_response(response, method_name, only_version_check=False):
logger.error(
f"Fail to call {method_name}, response status: {response.status_code}"
)
if (
len(
reponse_split_up := response.text.split(
"The following exception has occurred:"
try:
if (
len(
reponse_split_up := response.text.split(
"The following exception has occurred:"
)
)
)
> 1
):
raise RuntimeError(
f"Fail to call {method_name} with error: {reponse_split_up[1]}"
)
> 1
):
relevant_reponse_text = reponse_split_up[1].split(
"debug_error_string"
)[0]
if "ValueError" in relevant_reponse_text:
# Extract the ValueError message
value_error_msg = relevant_reponse_text.split(
"ValueError. Arguments: ("
)[1].split(",)")[0]
# Remove extra quotes and spaces
value_error_msg = value_error_msg.strip("'")
# Raise the ValueError with the extracted message
raise ValueError(value_error_msg)
raise RuntimeError(relevant_reponse_text)
except Exception as e:
if isinstance(e, (ValueError, RuntimeError)):
raise e
raise RuntimeError(
f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}"
f"Fail to call {method_name} with error: {response.status_code}, reason: "
f"{response.reason_phrase} and text: {response.text}"
)

def try_connection(self) -> bool:
Expand Down
17 changes: 16 additions & 1 deletion tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,23 @@ def __init__(
self.add_fingerprint_features = add_fingerprint_features
self.subsample_samples = subsample_samples

def _validate_targets_and_classes(self, y) -> np.ndarray:
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import check_classification_targets

y_ = column_or_1d(y, warn=True)
check_classification_targets(y)

# Get classes and encode before type conversion to guarantee correct class labels.
not_nan_mask = ~np.isnan(y)
self.classes_ = np.unique(y_[not_nan_mask])

def fit(self, X, y):
# assert init() is called
init()

self._validate_targets_and_classes(y)

if config.g_tabpfn_config.use_server:
try:
assert (
Expand All @@ -203,7 +216,9 @@ def fit(self, X, y):

def predict(self, X):
probas = self.predict_proba(X)
return np.argmax(probas, axis=1)
y = np.argmax(probas, axis=1)
y = self.classes_.take(np.asarray(y, dtype=int))
return y

def predict_proba(self, X):
check_is_fitted(self)
Expand Down

0 comments on commit 7af0659

Please sign in to comment.