Skip to content

Commit

Permalink
Updated morf and utils code
Browse files Browse the repository at this point in the history
  • Loading branch information
adam2392 committed Aug 29, 2024
1 parent a201ba1 commit 6fe71a5
Show file tree
Hide file tree
Showing 8 changed files with 829 additions and 208 deletions.
6 changes: 0 additions & 6 deletions treeple/tree/_oblique_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ cdef class BaseObliqueSplitter(Splitter):
SplitRecord* split,
) except -1 nogil

cdef inline void fisher_yates_shuffle_memview(
self,
intp_t[::1] indices_to_sample,
intp_t grid_size,
uint32_t* random_state
) noexcept nogil

cdef class ObliqueSplitter(BaseObliqueSplitter):
# The splitter searches in the input space for a linear combination of features and a threshold
Expand Down
80 changes: 50 additions & 30 deletions treeple/tree/_oblique_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from libcpp.vector cimport vector

from .._lib.sklearn.tree._criterion cimport Criterion
from .._lib.sklearn.tree._utils cimport rand_int, rand_uniform
from ._utils cimport fisher_yates_shuffle


cdef float64_t INFINITY = np.inf
Expand Down Expand Up @@ -46,8 +47,12 @@ cdef class BaseObliqueSplitter(Splitter):
def __setstate__(self, d):
pass

cdef int node_reset(self, intp_t start, intp_t end,
float64_t* weighted_n_node_samples) except -1 nogil:
cdef int node_reset(
self,
intp_t start,
intp_t end,
float64_t* weighted_n_node_samples
) except -1 nogil:
"""Reset splitter on node samples[start:end].
Returns -1 in case of failure to allocate memory (and raise MemoryError)
Expand All @@ -62,17 +67,7 @@ cdef class BaseObliqueSplitter(Splitter):
weighted_n_node_samples : ndarray, dtype=float64_t pointer
The total weight of those samples
"""

self.start = start
self.end = end

self.criterion.init(self.y,
self.sample_weight,
self.weighted_n_samples,
self.samples)
self.criterion.set_sample_pointers(start, end)

weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples
Splitter.node_reset(self, start, end, weighted_n_node_samples)

# Clear all projection vectors
for i in range(self.max_features):
Expand Down Expand Up @@ -102,8 +97,8 @@ cdef class BaseObliqueSplitter(Splitter):
intp_t end,
const intp_t[:] samples,
float32_t[:] feature_values,
vector[float32_t]* proj_vec_weights, # weights of the vector (max_features,)
vector[intp_t]* proj_vec_indices # indices of the features (max_features,)
vector[float32_t]* proj_vec_weights, # weights of the vector for this projection (n_non_zeros',)
vector[intp_t]* proj_vec_indices # indices of the features for this projection (n_non_zeros',)
) noexcept nogil:
"""Compute the feature values for the samples[start:end] range.
Expand All @@ -126,20 +121,6 @@ cdef class BaseObliqueSplitter(Splitter):
feature_values[idx] = 0.0
feature_values[idx] += self.X[samples[idx], col_idx] * col_weight

cdef inline void fisher_yates_shuffle_memview(
self,
intp_t[::1] indices_to_sample,
intp_t grid_size,
uint32_t* random_state,
) noexcept nogil:
cdef intp_t i, j

# XXX: should this be `i` or `i+1`? for valid Fisher-Yates?
for i in range(0, grid_size - 1):
j = rand_int(i, grid_size, random_state)
indices_to_sample[j], indices_to_sample[i] = \
indices_to_sample[i], indices_to_sample[j]

cdef class ObliqueSplitter(BaseObliqueSplitter):
def __cinit__(
self,
Expand Down Expand Up @@ -220,6 +201,43 @@ cdef class ObliqueSplitter(BaseObliqueSplitter):
# self.feature_weights = np.ones((self.n_features,), dtype=float32_t) / self.n_features
return 0

cdef void sample_proj_vec(
self,
vector[float32_t]& proj_vec_weights,
vector[intp_t]& proj_vec_indices
) noexcept nogil:
cdef intp_t n_features = self.n_features
cdef intp_t n_non_zeros = self.n_non_zeros
cdef uint32_t* random_state = &self.rand_r_state

cdef intp_t i, feat_i, proj_i, rand_vec_index
cdef float32_t weight

# construct an array to sample from mTry x n_features set of indices
cdef intp_t[::1] indices_to_sample = self.indices_to_sample
cdef intp_t grid_size = self.max_features * self.n_features

# shuffle indices over the 2D grid to sample using Fisher-Yates
fisher_yates_shuffle(indices_to_sample, grid_size, random_state)

# sample 'n_non_zeros' in a mtry X n_features projection matrix
# which consists of +/- 1's chosen at a 1/2s rate
for i in range(0, n_non_zeros):
# get the next index from the shuffled index array
rand_vec_index = indices_to_sample[i]

# get the projection index (i.e. row of the projection matrix) and
# feature index (i.e. column of the projection matrix)
proj_i = rand_vec_index // n_features
feat_i = rand_vec_index % n_features

# sample a random weight
weight = 1 if (rand_int(0, 2, random_state) == 1) else -1

proj_vec_indices[proj_i].push_back(feat_i) # Store index of nonzero
proj_vec_weights[proj_i].push_back(weight) # Store weight of nonzero


cdef void sample_proj_mat(
self,
vector[vector[float32_t]]& proj_mat_weights,
Expand Down Expand Up @@ -257,7 +275,7 @@ cdef class ObliqueSplitter(BaseObliqueSplitter):
cdef intp_t grid_size = self.max_features * self.n_features

# shuffle indices over the 2D grid to sample using Fisher-Yates
self.fisher_yates_shuffle_memview(indices_to_sample, grid_size, random_state)
fisher_yates_shuffle(indices_to_sample, grid_size, random_state)

# sample 'n_non_zeros' in a mtry X n_features projection matrix
# which consists of +/- 1's chosen at a 1/2s rate
Expand Down Expand Up @@ -340,6 +358,8 @@ cdef class BestObliqueSplitter(ObliqueSplitter):
# XXX: 'feature' is not actually used in oblique split records
# Just indicates which split was sampled
current_split.feature = feat_i

# sample the projection vector
current_split.proj_vec_weights = &self.proj_mat_weights[feat_i]
current_split.proj_vec_indices = &self.proj_mat_indices[feat_i]

Expand Down
30 changes: 28 additions & 2 deletions treeple/tree/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,23 @@ cimport numpy as cnp

cnp.import_array()

from libcpp.vector cimport vector

from .._lib.sklearn.tree._splitter cimport SplitRecord
from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int32_t, intp_t, uint32_t

ctypedef fused vector_or_memview:
vector[intp_t]
intp_t[::1]
intp_t[:]


cdef inline void fisher_yates_shuffle(
vector_or_memview indices_to_sample,
intp_t grid_size,
uint32_t* random_state,
) noexcept nogil


cdef intp_t rand_weighted_binary(
float64_t p0,
Expand All @@ -22,10 +36,22 @@ cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape)
cdef void unravel_index_cython(
intp_t index,
const intp_t[:] shape,
const intp_t[:] coords
vector_or_memview coords
) noexcept nogil

cdef intp_t ravel_multi_index_cython(
const intp_t[:] coords,
vector_or_memview coords,
const intp_t[:] shape
) noexcept nogil

cdef void compute_vectorized_indices_from_cartesian(
intp_t base_index,
vector[vector[intp_t]]& index_arrays,
const intp_t[:] data_dims,
vector[intp_t]& result
) noexcept nogil

cdef memoryview[float32_t, ndim=3] init_2dmemoryview(
cnp.ndarray array,
const intp_t[:] data_dims
)
Loading

0 comments on commit 6fe71a5

Please sign in to comment.