Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data checks in client.py #30

Merged
merged 9 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def upload_train_set(self, X, y) -> str:
The unique ID of the train set in the server.

"""

X = common_utils.serialize_to_csv_formatted_bytes(X)
y = common_utils.serialize_to_csv_formatted_bytes(y)

Expand Down
36 changes: 36 additions & 0 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
from tabpfn_client import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted
from sklearn.utils import check_consistent_length

from tabpfn_client import config

logger = logging.getLogger(__name__)

# Upper bounds on the size of input data. validate_data_size() raises a
# ValueError when X has more rows than MAX_ROWS or more columns than MAX_COLS.
MAX_ROWS = 10000
MAX_COLS = 500


@dataclass(eq=True, frozen=True)
class PreprocessorConfig:
Expand Down Expand Up @@ -194,10 +198,16 @@ def _validate_targets_and_classes(self, y) -> np.ndarray:
not_nan_mask = ~np.isnan(y)
self.classes_ = np.unique(y_[not_nan_mask])

@staticmethod
def _validate_data_size(X: np.ndarray, y: np.ndarray | None):
if X.shape[0] != y.shape[0]:
raise ValueError("X and y must have the same number of samples")

def fit(self, X, y):
# assert init() is called
init()

validate_data_size(X, y)
self._validate_targets_and_classes(y)

if config.g_tabpfn_config.use_server:
Expand All @@ -207,6 +217,7 @@ def fit(self, X, y):
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)

self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
else:
Expand All @@ -223,6 +234,8 @@ def predict(self, X):

def predict_proba(self, X):
check_is_fitted(self)
validate_data_size(X)

return config.g_tabpfn_config.inference_handler.predict(
X,
task="classification",
Expand Down Expand Up @@ -344,6 +357,8 @@ def fit(self, X, y):
# assert init() is called
init()

validate_data_size(X, y)

if config.g_tabpfn_config.use_server:
self.last_train_set_uid = config.g_tabpfn_config.inference_handler.fit(X, y)
self.fitted_ = True
Expand All @@ -366,6 +381,7 @@ def predict(self, X):

def predict_full(self, X):
check_is_fitted(self)
validate_data_size(X)

estimator_param = self.get_params()
if "model" in estimator_param:
Expand Down Expand Up @@ -393,3 +409,23 @@ def _model_name_to_path(self, model_name: str) -> str:
return f"{base_path}_{model_name}.ckpt"
else:
raise ValueError(f"Invalid model name: {model_name}")


def validate_data_size(X: np.ndarray, y: np.ndarray | None = None):
    """
    Check the size of the input data (used for both fitting and prediction).
    - check if the number of rows between X and y is consistent
      if y is not None (ValueError)
    - check if the number of rows is at most MAX_ROWS (ValueError)
    - check if the number of columns is at most MAX_COLS (ValueError)

    X: feature data; array-likes without a `.shape` attribute (e.g. nested
       lists) are accepted as well.
    y: optional targets; pass None (the default) when validating
       prediction inputs.
    """

    # check if the number of samples is consistent (ValueError)
    if y is not None:
        check_consistent_length(X, y)

    # np.shape works for arrays and for plain array-likes, so the size limits
    # are enforced even when X has no `.shape` attribute.
    shape = np.shape(X)

    # length and feature assertions; each dimension is only checked when it
    # exists, so 1-D input does not fail with an IndexError here.
    if len(shape) > 0 and shape[0] > MAX_ROWS:
        raise ValueError(f"The number of rows cannot be more than {MAX_ROWS}.")
    if len(shape) > 1 and shape[1] > MAX_COLS:
        raise ValueError(f"The number of columns cannot be more than {MAX_COLS}.")
64 changes: 63 additions & 1 deletion tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import shutil

import numpy as np
Expand All @@ -14,6 +14,7 @@
from tabpfn_client.client import ServiceClient
from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client import config


class TestTabPFNClassifierInit(unittest.TestCase):
Expand Down Expand Up @@ -160,3 +161,64 @@ def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_con

self.assertRaises(RuntimeError, init, use_server=True)
self.assertTrue(mock_prompt_for_terms_and_cond.called)


class TestTabPFNClassifierInference(unittest.TestCase):
    """Data-size validation checks on TabPFNClassifier.fit and .predict."""

    def setUp(self):
        # mark the global config as initialized so the tests skip init
        config.g_tabpfn_config.is_initialized = True

    def tearDown(self):
        # restore the global config touched in setUp
        config.reset()

    def test_data_size_check_on_train_with_inconsistent_number_of_samples_raise_error(
        self,
    ):
        # 10 feature rows but 11 labels
        features = np.random.rand(10, 5)
        labels = np.random.randint(0, 2, 11)
        classifier = TabPFNClassifier()

        self.assertRaises(ValueError, classifier.fit, features, labels)

    def test_data_size_check_on_train_with_oversized_data_raise_error(self):
        # one row and one column more than the limits asserted below
        features = np.random.randn(10001, 501)
        labels = np.random.randint(0, 2, 10001)

        classifier = TabPFNClassifier()

        # test oversized columns
        few_rows, few_labels = features[:10], labels[:10]
        self.assertRaises(ValueError, classifier.fit, few_rows, few_labels)

        # test oversized rows
        few_cols = features[:, :10]
        self.assertRaises(ValueError, classifier.fit, few_cols, labels)

    def test_data_size_check_on_predict_with_oversized_data_raise_error(self):
        oversized_X = np.random.randn(10001, 5)
        classifier = TabPFNClassifier()

        # skip fitting
        classifier.fitted_ = True

        # test oversized rows
        self.assertRaises(ValueError, classifier.predict, oversized_X)

    def test_data_check_on_predict_with_valid_data_pass(self):
        valid_X = np.random.randn(10, 5)
        classifier = TabPFNClassifier()

        # skip fitting
        classifier.fitted_ = True
        classifier.classes_ = np.array([0, 1])

        # mock prediction
        handler = MagicMock()
        config.g_tabpfn_config.inference_handler = handler
        handler.predict = MagicMock(return_value={"probas": np.random.rand(10, 2)})

        classifier.predict(valid_X)
64 changes: 63 additions & 1 deletion tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import shutil
import numpy as np
from sklearn.datasets import load_diabetes
Expand All @@ -13,6 +14,7 @@
from tabpfn_client.client import ServiceClient
from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client import config


class TestTabPFNRegressorInit(unittest.TestCase):
Expand Down Expand Up @@ -175,3 +177,63 @@ def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_con

self.assertRaises(RuntimeError, init, use_server=True)
self.assertTrue(mock_prompt_for_terms_and_cond.called)


class TestTabPFNRegressorInference(unittest.TestCase):
    """Data-size validation checks on TabPFNRegressor.fit and .predict."""

    def setUp(self):
        # mark the global config as initialized so the tests skip init
        config.g_tabpfn_config.is_initialized = True

    def tearDown(self):
        # restore the global config touched in setUp
        config.reset()

    def test_data_size_check_on_train_with_inconsistent_number_of_samples_raise_error(
        self,
    ):
        # 10 feature rows but 11 targets
        features = np.random.rand(10, 5)
        targets = np.random.rand(11)
        regressor = TabPFNRegressor()

        self.assertRaises(ValueError, regressor.fit, features, targets)

    def test_data_size_check_on_train_with_oversized_data_raise_error(self):
        # one row and one column more than the limits asserted below
        features = np.random.randn(10001, 501)
        targets = np.random.randn(10001)

        regressor = TabPFNRegressor()

        # test oversized columns
        few_rows, few_targets = features[:10], targets[:10]
        self.assertRaises(ValueError, regressor.fit, few_rows, few_targets)

        # test oversized rows
        few_cols = features[:, :10]
        self.assertRaises(ValueError, regressor.fit, few_cols, targets)

    def test_data_size_check_on_predict_with_oversized_data_raise_error(self):
        oversized_X = np.random.randn(10001, 5)
        regressor = TabPFNRegressor()

        # skip fitting
        regressor.fitted_ = True

        # test oversized rows
        self.assertRaises(ValueError, regressor.predict, oversized_X)

    def test_data_check_on_predict_with_valid_data_pass(self):
        valid_X = np.random.randn(10, 5)
        regressor = TabPFNRegressor()

        # skip fitting
        regressor.fitted_ = True

        # mock prediction
        handler = MagicMock()
        config.g_tabpfn_config.inference_handler = handler
        handler.predict = MagicMock(return_value={"mean": np.random.randn(10)})

        regressor.predict(valid_X)
Loading