Skip to content

Commit 3b31425

Browse files
authored
Fix mypy issue by adding plugin, add tests (#534)
* Fix mypy issue by adding plugin, add tests * Add test for sklearn model wrapper find_best method * Small fix, add [all] option to extras install * Additional fixes, test
1 parent cbe9eeb commit 3b31425

File tree

10 files changed

+862
-722
lines changed

10 files changed

+862
-722
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ repos:
2828
types_or: [python, jupyter]
2929

3030
- repo: https://github.com/pre-commit/mirrors-mypy
31-
rev: v1.6.1
31+
rev: v1.7.1
3232
hooks:
3333
- id: mypy
3434
entry: python3 -m mypy --config-file pyproject.toml

cyclops/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
register_model(name="sgd_regressor", model_type="static")(SGDRegressor)
4141
register_model("rf_classifier", model_type="static")(RandomForestClassifier)
4242
register_model("logistic_regression", model_type="static")(LogisticRegression)
43-
register_model("mlp", model_type="static")(MLPClassifier)
43+
register_model("mlp_classifier", model_type="static")(MLPClassifier)
4444
if XGBClassifier is not None:
4545
register_model("xgb_classifier", model_type="static")(XGBClassifier)
4646
if DenseNet is not None:

cyclops/models/catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def create_model(
224224
raise RuntimeError(_xgboost_unavailable_message)
225225
if model_name in ["densenet", "resnet"]:
226226
raise RuntimeError(_torchxrayvision_unavailable_message)
227-
if model_name in ["gru", "lstm", "mlp", "rnn"]:
227+
if model_name in ["gru", "lstm", "mlp_pt", "rnn"]:
228228
raise RuntimeError(_torch_unavailable_message)
229229
similar_keys_list: List[str] = get_close_matches(
230230
model_name,

cyclops/models/wrappers/sk_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def find_best( # noqa: PLR0912, PLR0915
187187
if isinstance(X, (Dataset, DatasetDict)):
188188
if feature_columns is None:
189189
raise ValueError(
190-
"Missing target columns 'target_columns'. Please provide \
190+
"Missing target columns 'feature_columns'. Please provide \
191191
the name of feature columns when using a \
192192
Hugging Face dataset as the input.",
193193
)
@@ -336,10 +336,11 @@ def find_best( # noqa: PLR0912, PLR0915
336336
)
337337
clf.fit(X, y)
338338

339-
for key, value in clf["clf"].best_params_.items():
339+
if isinstance(clf, Pipeline):
340+
clf = clf["clf"]
341+
for key, value in clf.best_params_.items():
340342
LOGGER.info("Best %s: %s", key, value)
341-
342-
self.model_ = clf["clf"].best_estimator_
343+
self.model_ = clf.best_estimator_
343344

344345
return self
345346

poetry.lock

Lines changed: 736 additions & 712 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ black = "^22.1.0"
9898
pytest-cov = "^3.0.0"
9999
codecov = "^2.1.13"
100100
nbstripout = "^0.6.1"
101-
mypy = "^1.0.0"
101+
mypy = "^1.7.0"
102102
ruff = "^0.1.0"
103103
nbqa = { version = "^1.7.0", extras = ["toolchain"] }
104104
cycquery = "^0.1.2" # used for integration test
@@ -146,8 +146,10 @@ monai = ["torch", "monai"]
146146
xgboost = ["xgboost"]
147147
alibi = ["llvmlite", "alibi"]
148148
alibi-detect = ["torch", "llvmlite", "alibi-detect"]
149+
all = ["torch", "torchvision", "torchxrayvision", "llvmlite", "monai", "xgboost", "alibi", "alibi-detect"]
149150

150151
[tool.mypy]
152+
plugins = ["numpy.typing.mypy_plugin"]
151153
ignore_missing_imports = true
152154
install_types = true
153155
pretty = true
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for scikit-learn model wrapper."""
2+
3+
import pandas as pd
4+
from datasets import Dataset
5+
from sklearn.datasets import load_diabetes
6+
7+
from cyclops.models import create_model
8+
from cyclops.models.wrappers import SKModel
9+
10+
11+
def test_find_best_grid_search():
12+
"""Test find_best method with grid search."""
13+
parameters = {"C": [1], "l1_ratio": [0.5]}
14+
X, y = load_diabetes(return_X_y=True)
15+
metric = "accuracy"
16+
method = "grid"
17+
18+
model = create_model("logistic_regression", penalty="elasticnet", solver="saga")
19+
best_estimator = model.find_best(
20+
parameters=parameters,
21+
X=X,
22+
y=y,
23+
metric=metric,
24+
method=method,
25+
)
26+
assert isinstance(best_estimator, SKModel)
27+
28+
29+
def test_find_best_random_search():
30+
"""Test find_best method with random search."""
31+
parameters = {"alpha": [0.001], "hidden_layer_sizes": [10]}
32+
X, y = load_diabetes(return_X_y=True)
33+
metric = "accuracy"
34+
method = "random"
35+
36+
model = create_model("mlp_classifier", early_stopping=True)
37+
best_estimator = model.find_best(
38+
parameters=parameters,
39+
X=X,
40+
y=y,
41+
metric=metric,
42+
method=method,
43+
)
44+
assert isinstance(best_estimator, SKModel)
45+
46+
47+
def test_find_best_hf_dataset_input():
48+
"""Test find_best method with huggingface dataset input."""
49+
parameters = {"alpha": [0.001], "hidden_layer_sizes": [10]}
50+
data = load_diabetes(as_frame=True)
51+
X, y = data["data"], data["target"]
52+
X_y = pd.concat([X, y], axis=1)
53+
features_names = data["feature_names"]
54+
dataset = Dataset.from_pandas(X_y)
55+
metric = "accuracy"
56+
method = "random"
57+
58+
model = create_model("mlp_classifier", early_stopping=True)
59+
best_estimator = model.find_best(
60+
parameters=parameters,
61+
X=dataset,
62+
metric=metric,
63+
method=method,
64+
feature_columns=features_names,
65+
target_columns="target",
66+
)
67+
assert isinstance(best_estimator, SKModel)

tests/cyclops/models/wrappers/test_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,53 @@
55
import torch
66
from datasets import Dataset
77

8-
from cyclops.models.wrappers.utils import DatasetColumn, to_numpy, to_tensor
8+
from cyclops.models.wrappers.utils import (
9+
DatasetColumn,
10+
get_params,
11+
set_params,
12+
to_numpy,
13+
to_tensor,
14+
)
15+
16+
17+
def test_set_params():
18+
"""Test set_params function."""
19+
20+
class ExampleClass:
21+
"""Example class for testing."""
22+
23+
def __init__(self, param1, param2, param3):
24+
"""Initialize the class."""
25+
self.param1 = param1
26+
self.param2 = param2
27+
self.param3 = param3
28+
29+
params = {"param1": 10, "param2": "hello", "param3": True}
30+
example_class = ExampleClass(1, "world", False)
31+
set_params(example_class, **params)
32+
assert example_class.param1 == 10
33+
assert example_class.param2 == "hello"
34+
assert example_class.param3 is True
35+
36+
37+
def test_get_params():
38+
"""Test get_params function."""
39+
40+
class ExampleClass:
41+
"""Example class for testing."""
42+
43+
def __init__(self, param1, param2, param3):
44+
"""Initialize the class."""
45+
self.param1 = param1
46+
self.param2 = param2
47+
self.param3 = param3
48+
49+
result = get_params(ExampleClass(10, "hello", True))
50+
assert isinstance(result, dict)
51+
assert len(result) == 3
52+
assert result["param1"] == 10
53+
assert result["param2"] == "hello"
54+
assert result["param3"] is True
955

1056

1157
@pytest.mark.integration_test()

tests/cyclops/tasks/test_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestBinaryTabularClassificationTask(TestCase):
1919

2020
def setUp(self):
2121
"""Set up for testing."""
22-
self.model_name = "mlp"
22+
self.model_name = "mlp_classifier"
2323
self.model = create_model(self.model_name)
2424
self.test_task = BinaryTabularClassificationTask(
2525
{self.model_name: self.model},

0 commit comments

Comments
 (0)