Skip to content

Commit

Permalink
test: add simple outline to test the predict method in sklearn frontend.
Browse files Browse the repository at this point in the history
the integration with already exisiting strategies requires some work as predict method assume fit called etc.
  • Loading branch information
Ishticode committed Dec 28, 2023
1 parent 56bf012 commit d78fd4a
Showing 1 changed file with 37 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,41 @@
import ivy.functional.frontends.sklearn as sklearn_frontend
import numpy as np
import ivy
from hypothesis import given
import ivy_tests.test_ivy.helpers as helpers


# --- Helpers --- #
# --------------- #


# helper functions
def _get_sklearn_predict(X, y, max_depth):
ivy_clf = sklearn_frontend.tree.DecisionTreeClassifier(max_depth=max_depth)
ivy_clf.fit(X, y)
return ivy_clf.predict
def _get_sklearn_predict(X, y, max_depth, module=None):
clf = module.tree.DecisionTreeClassifier(max_depth=max_depth)
clf.fit(X, y)
return clf.predict


# --- Main --- #
# ------------ #


# todo: integrate with already existing strats and generalize
@given(
X=helpers.array_values(shape=(5, 2), dtype=helpers.get_dtypes("float")),
y=helpers.array_values(shape=(5,), dtype=helpers.get_dtypes("integer")),
)
def test_sklearn_tree_predict(X, y):
try:
import sklearn
except ImportError:
print("sklearn not installed, skipping test_sklearn_tree_predict")
return
sklearn_pred = _get_sklearn_predict(X, y, max_depth=3, module=sklearn)(X)
for fw in helpers.available_frameworks:
ivy.set_backend(fw)
ivy_pred = _get_sklearn_predict(
ivy.array(X), ivy.array(y), max_depth=3, module=sklearn_frontend
)(X)
assert np.allclose(ivy_pred.to_numpy(), sklearn_pred)
ivy.unset_backend()

0 comments on commit d78fd4a

Please sign in to comment.