Skip to content

Commit

Permalink
Update submodule
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Aug 16, 2023
1 parent d5ce01c commit 565e4d6
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 57 deletions.
2 changes: 1 addition & 1 deletion sktree/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,7 @@ def _build_tree(
)

if self.data_dims is None:
self.data_dims_ = np.array((1, X.shape[1]), dtype=np.intp)
self.data_dims_ = np.array((1, X.shape[1]), dtype=np.int8)
else:
if np.prod(self.data_dims) != X.shape[1]:
raise RuntimeError(f"Data dimensions {self.data_dims} do not match {X.shape[1]}.")
Expand Down
55 changes: 0 additions & 55 deletions sktree/tree/manifold/_morf_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ cdef class PatchSplitter(BaseObliqueSplitter):
const DOUBLE_t[:, ::1] y,
const DOUBLE_t[:] sample_weight,
const unsigned char[::1] missing_values_in_feature_mask,
# const INT32_t[:] n_categories
) except -1:
BaseObliqueSplitter.init(self, X, y, sample_weight, missing_values_in_feature_mask)

Expand Down Expand Up @@ -198,60 +197,6 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter):
self.feature_weight.base if self.feature_weight is not None else None,
), self.__getstate__())

# def __getstate__(self):
# """Getstate re-implementation, for pickling."""
# d = {}
# # capacity is inferred during the __setstate__ using nodes
# d["criterion"] = self.criterion
# d["max_features"] = self.max_features
# d["min_samples_leaf"] = self.min_samples_leaf
# d["min_weight_leaf"] = self.min_weight_leaf
# d['random_state'] = self.random_state
# d['monotonic_cst'] = self.monotonic_cst.base
# d['min_patch_dims'] = self.min_patch_dims.base
# d['max_patch_dims'] = self.max_patch_dims.base
# d['dim_contiguous'] = self.dim_contiguous.base
# d['data_dims'] = self.data_dims.base
# d['boundary'] = self.boundary
# d['feature_weight'] = self.feature_weight.base
# return d

# def __setstate__(self, d):
# self.criterion = d["criterion"]
# self.max_features = d["max_features"]
# self.min_samples_leaf = d["min_samples_leaf"]
# self.min_weight_leaf = d["min_weight_leaf"]
# self.random_state = d['random_state']
# self.monotonic_cst = d['monotonic_cst']

# self.min_patch_dims = d['min_patch_dims']
# self.max_patch_dims = d['max_patch_dims']
# self.dim_contiguous = d['dim_contiguous']
# self.data_dims = d['data_dims']
# self.boundary = d['boundary']
# self.feature_weight = d['feature_weight']

# whether or not to perform some discontinuous sampling
# if not all(self.dim_contiguous):
# self._discontiguous = True
# else:
# self._discontiguous = False

# # Sparse max_features x n_features projection matrix
# self.proj_mat_weights = vector[vector[DTYPE_t]](self.max_features)
# self.proj_mat_indices = vector[vector[SIZE_t]](self.max_features)

# # initialize state to allow generalization to higher-dimensional tensors
# self.ndim = self.data_dims.shape[0]

# # create a buffer for storing the patch dimensions sampled per projection matrix
# self.patch_dims_buff = np.zeros(self.data_dims.shape[0], dtype=np.intp)
# self.unraveled_patch_point = np.zeros(self.data_dims.shape[0], dtype=np.intp)

# # # initialize a buffer to allow for Fisher-Yates
# self._index_patch_buffer = np.zeros(np.max(self.max_patch_dims), dtype=np.intp)
# self._index_data_buffer = np.zeros(np.max(self.data_dims), dtype=np.intp)

cdef (SIZE_t, SIZE_t) sample_top_left_seed(self) noexcept nogil:
"""Sample the top-left seed for the n-dim patch.
Expand Down
2 changes: 1 addition & 1 deletion sktree/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_pickle_splitters():
[
ObliqueDecisionTreeClassifier(random_state=12),
ObliqueDecisionTreeRegressor(random_state=12),
PatchObliqueDecisionTreeClassifier(random_state=12),
# PatchObliqueDecisionTreeClassifier(random_state=12),
PatchObliqueDecisionTreeRegressor(random_state=12),
]
)
Expand Down

0 comments on commit 565e4d6

Please sign in to comment.