diff --git a/pyproject.toml b/pyproject.toml index 47701fb..7d0ce3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tabpfn-client" -version = "0.0.15" +version = "0.0.20" requires-python = ">=3.10" dependencies = [ "httpx>=0.24.1", diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 4fe1c4c..21d2a9d 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -160,7 +160,9 @@ def predict( return result @staticmethod - def _validate_response(response, method_name, only_version_check=False): + def _validate_response( + response: httpx.Response, method_name, only_version_check=False + ): # If status code is 200, no errors occurred on the server side. if response.status_code == 200: return @@ -170,11 +172,11 @@ def _validate_response(response, method_name, only_version_check=False): try: load = response.json() except json.JSONDecodeError as e: - logging.error(f"Failed to parse JSON from response in {method_name}: {e}") + logging.info(f"Failed to parse JSON from response in {method_name}: {e}") # Check if the server requires a newer client version. if response.status_code == 426: - logger.error( + logger.info( f"Fail to call {method_name}, response status: {response.status_code}" ) raise RuntimeError(load.get("detail")) @@ -186,20 +188,34 @@ def _validate_response(response, method_name, only_version_check=False): logger.error( f"Fail to call {method_name}, response status: {response.status_code}" ) - if ( - len( - reponse_split_up := response.text.split( - "The following exception has occurred:" + try: + if ( + len( + reponse_split_up := response.text.split( + "The following exception has occurred:" + ) ) - ) - > 1 - ): - raise RuntimeError( - f"Fail to call {method_name} with error: {reponse_split_up[1]}" - ) + > 1 + ): + relevant_reponse_text = reponse_split_up[1].split( + "debug_error_string" + )[0] + if "ValueError" in relevant_reponse_text: + # Extract the ValueError message + value_error_msg = relevant_reponse_text.split( + "ValueError. Arguments: (" + )[1].split(",)")[0] + # Remove extra quotes and spaces + value_error_msg = value_error_msg.strip("'") + # Raise the ValueError with the extracted message + raise ValueError(value_error_msg) + raise RuntimeError(relevant_reponse_text) + except Exception as e: + if isinstance(e, (ValueError, RuntimeError)): + raise e raise RuntimeError( - f"Fail to call {method_name} with error: {response.status_code} and reason: " - f"{response.reason_phrase}" + f"Fail to call {method_name} with error: {response.status_code}, reason: " + f"{response.reason_phrase} and text: {response.text}" ) def try_connection(self) -> bool: diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 7d043b7..65f061f 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -182,10 +182,23 @@ def __init__( self.add_fingerprint_features = add_fingerprint_features self.subsample_samples = subsample_samples + def _validate_targets_and_classes(self, y) -> np.ndarray: + from sklearn.utils import column_or_1d + from sklearn.utils.multiclass import check_classification_targets + + y_ = column_or_1d(y, warn=True) + check_classification_targets(y) + + # Get classes and encode before type conversion to guarantee correct class labels. + not_nan_mask = ~np.isnan(y) + self.classes_ = np.unique(y_[not_nan_mask]) + def fit(self, X, y): # assert init() is called init() + self._validate_targets_and_classes(y) + if config.g_tabpfn_config.use_server: try: assert ( @@ -203,7 +216,9 @@ def fit(self, X, y): def predict(self, X): probas = self.predict_proba(X) - return np.argmax(probas, axis=1) + y = np.argmax(probas, axis=1) + y = self.classes_.take(np.asarray(y, dtype=int)) + return y def predict_proba(self, X): check_is_fitted(self)