Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Sep 9, 2024
1 parent 7221741 commit e0dc4a5
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions treeple/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.neighbors import NearestNeighbors
from sklearn.utils.validation import check_is_fitted, validate_data

from treeple.tree import DecisionTreeClassifier
from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix


Expand All @@ -31,13 +32,19 @@ class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin):
The number of parallel jobs to run for neighbors, by default None.
"""

def __init__(self, estimator, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
def __init__(self, estimator=None, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
self.estimator = estimator
self.n_neighbors = n_neighbors
self.algorithm = algorithm
self.radius = radius
self.n_jobs = n_jobs

def get_estimator(self):
if self.estimator is not None:
return DecisionTreeClassifier(random_state=0)
else:
return copy(self.estimator)

Check warning on line 46 in treeple/neighbors.py

View check run for this annotation

Codecov / codecov/patch

treeple/neighbors.py#L46

Added line #L46 was not covered by tests

def fit(self, X, y=None):
"""Fit the nearest neighbors estimator from the training dataset.
Expand All @@ -58,7 +65,7 @@ def fit(self, X, y=None):
"""
X, y = validate_data(self, X, y, accept_sparse="csc")

self.estimator_ = copy(self.estimator)
self.estimator_ = self.get_estimator()
try:
check_is_fitted(self.estimator_)
except NotFittedError:
Expand Down

0 comments on commit e0dc4a5

Please sign in to comment.