Skip to content

Commit

Permalink
new wip
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Aug 28, 2024
1 parent a73e5a7 commit a201ba1
Show file tree
Hide file tree
Showing 12 changed files with 662 additions and 148 deletions.
77 changes: 77 additions & 0 deletions examples/splitters/plot_kernel_splitters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
====================
Random Kernel Splits
====================
This example shows how to build a manifold oblique decision tree classifier using
a custom set of user-defined kernel/filter library, such as the Gaussian, or Gabor
kernels.
The example demonstrates superior performance on a 2D dataset with structured images
as samples. The dataset is the downsampled MNIST dataset, where each sample is a
28x28 image. The dataset is downsampled to 14x14, and then flattened to a 196
dimensional vector. The dataset is then split into a training and testing set.
See :ref:`sphx_glr_auto_examples_plot_projection_matrices` for more information on
projection matrices and the way they can be sampled.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csr_matrix

from treeple.tree.manifold._kernel_splitter import Kernel2D

# %%
# Create a synthetic image
image_height, image_width = 50, 50
image = np.random.rand(image_height, image_width).astype(np.float32)

# Generate a Gaussian kernel (example)
kernel_size = 7
x = np.linspace(-2, 2, kernel_size)
y = np.linspace(-2, 2, kernel_size)
x, y = np.meshgrid(x, y)
kernel = np.exp(-(x**2 + y**2))
kernel = kernel / kernel.sum() # Normalize the kernel

# Vectorize and create a sparse CSR matrix
kernel_vector = kernel.flatten().astype(np.float32)
kernel_indices = np.arange(kernel_vector.size)
kernel_indptr = np.array([0, kernel_vector.size])
kernel_csr = csr_matrix(
(kernel_vector, kernel_indices, kernel_indptr), shape=(1, kernel_vector.size)
)

# %%
# Initialize the Kernel2D class
kernel_sizes = np.array([kernel_size], dtype=np.intp)
random_state = np.random.RandomState(42)
print(kernel_csr.dtype, kernel_sizes.dtype, np.intp)
kernel_2d = Kernel2D(kernel_csr, kernel_sizes, random_state)

# Apply the kernel to the image
result_value = kernel_2d.apply_kernel_py(image, 0)

# %%
# Plot the original image, kernel, and result
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

axs[0].imshow(image, cmap="gray")
axs[0].set_title("Original Image")

axs[1].imshow(kernel, cmap="viridis")
axs[1].set_title("Gaussian Kernel")

# Highlight the region where the kernel was applied
start_x, start_y = random_state.randint(0, image_width - kernel_size + 1), random_state.randint(
0, image_height - kernel_size + 1
)
image_with_kernel = image.copy()
image_with_kernel[start_y : start_y + kernel_size, start_x : start_x + kernel_size] *= kernel

axs[2].imshow(image_with_kernel, cmap="gray")
axs[2].set_title(f"Result: {result_value:.4f}")

plt.tight_layout()
plt.show()
84 changes: 83 additions & 1 deletion treeple/tree/_kernel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy

import numpy as np
from scipy.sparse import issparse
from scipy.sparse import csr_matrix, issparse

from .._lib.sklearn.tree._criterion import BaseCriterion
from .._lib.sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder
Expand All @@ -20,6 +20,88 @@ def gaussian_kernel(shape, sigma=1.0, mu=0.0):
return g / np.sum(g)


def sample_gaussian_kernels(n_kernels, min_size, max_size, mu=0.0, sigma=1.0):
"""
Sample a set of Gaussian kernels and arrange them into a sparse kernel matrix.
Parameters:
-----------
n_kernels : int
The number of Gaussian kernels to sample.
min_size : int
The minimum size of the kernel (inclusive).
max_size : int
The maximum size of the kernel (inclusive).
mu : float or tuple of floats
The mean(s) of the Gaussian distribution. If a tuple, random mean values are sampled.
sigma : float or tuple of floats
The standard deviation(s) of the Gaussian distribution. If a tuple, random sigma values are sampled.
Returns:
--------
kernel_matrix : csr_matrix
The sparse matrix containing vectorized kernels.
kernel_params : dict of arrays
The parameters of the kernels that were sampled with keys:
- 'size': the size of each kernel; since the kernels are 2D square matrices
this is the side length of the kernel.
- 'mu': the mean of each kernel
- 'sigma': the standard deviation
"""
data = []
indices = []
indptr = [0]
kernel_params = {
"size": np.zeros(n_kernels, dtype=np.intp),
"mu": np.zeros(n_kernels),
"sigma": np.zeros(n_kernels),
}

for i in range(n_kernels):
# Sample the size of the kernel
size = np.random.randint(min_size, max_size + 1)

# Sample mu and sigma if they are tuples
if isinstance(mu, tuple):
mu_sample = np.random.uniform(mu[0], mu[1])
else:
mu_sample = mu

if isinstance(sigma, tuple):
sigma_sample = np.random.uniform(sigma[0], sigma[1])
else:
sigma_sample = sigma

# Create a meshgrid for the kernel
x = np.linspace(-1, 1, size)
y = np.linspace(-1, 1, size)
X, Y = np.meshgrid(x, y)

# Create the Gaussian kernel
kernel = np.exp(-((X - mu_sample) ** 2 + (Y - mu_sample) ** 2) / (2 * sigma_sample**2))

# Vectorize the kernel and store it in the sparse matrix format
kernel_vectorized = kernel.flatten()
data.extend(kernel_vectorized)
indices.extend(range(size * size))
indptr.append(len(data))

# Store the kernel parameters
kernel_params["size"][i] = size
kernel_params["mu"][i] = mu_sample
kernel_params["sigma"][i] = sigma_sample

# Convert lists to the appropriate format
data = np.array(data)
indices = np.array(indices)
indptr = np.array(indptr)

# Create a sparse matrix
kernel_matrix = csr_matrix((data, indices, indptr), shape=(n_kernels, max_size * max_size))

return kernel_matrix, kernel_params


class KernelDecisionTreeClassifier(PatchObliqueDecisionTreeClassifier):
"""Oblique decision tree classifier over data patches combined with Gaussian kernels.
Expand Down
2 changes: 1 addition & 1 deletion treeple/tree/_oblique_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ cdef class BestObliqueSplitter(ObliqueSplitter):
cdef intp_t end = self.end

# pointer array to store feature values to split on
cdef float32_t[::1] feature_values = self.feature_values
cdef float32_t[::1] feature_values = self.feature_values
cdef intp_t max_features = self.max_features
cdef intp_t min_samples_leaf = self.min_samples_leaf

Expand Down
16 changes: 13 additions & 3 deletions treeple/tree/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@ from .._lib.sklearn.tree._splitter cimport SplitRecord
from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int32_t, intp_t, uint32_t


cdef int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil
cdef intp_t rand_weighted_binary(
float64_t p0,
uint32_t* random_state
) noexcept nogil

cpdef unravel_index(
intp_t index, cnp.ndarray[intp_t, ndim=1] shape
)

cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape)

cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil
cdef void unravel_index_cython(
intp_t index,
const intp_t[:] shape,
const intp_t[:] coords
) noexcept nogil

cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil
cdef intp_t ravel_multi_index_cython(
const intp_t[:] coords,
const intp_t[:] shape
) noexcept nogil
40 changes: 36 additions & 4 deletions treeple/tree/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# cython: wraparound=False
# cython: initializedcheck=False

from libcpp.vector cimport vector

import numpy as np

cimport numpy as cnp
Expand All @@ -14,7 +16,10 @@ cnp.import_array()
from .._lib.sklearn.tree._utils cimport rand_uniform


cdef inline int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil:
cdef inline intp_t rand_weighted_binary(
float64_t p0,
uint32_t* random_state
) noexcept nogil:
"""Sample from integers 0 and 1 with different probabilities.
Parameters
Expand Down Expand Up @@ -83,7 +88,11 @@ cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape):
return ravel_multi_index_cython(coords, shape)


cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil:
cdef inline void unravel_index_cython(
intp_t index,
const intp_t[:] shape,
const intp_t[:] coords
) noexcept nogil:
"""Converts a flat index into a tuple of coordinate arrays.
Parameters
Expand All @@ -109,8 +118,11 @@ cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] co
index //= size


cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil:
"""Converts a tuple of coordinate arrays into a flat index.
cdef inline intp_t ravel_multi_index_cython(
const intp_t[:] coords,
const intp_t[:] shape
) noexcept nogil:
"""Converts a tuple of coordinate arrays into a flat index in the vectorized dimension.
Parameters
----------
Expand Down Expand Up @@ -145,3 +157,23 @@ cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) no
flat_index *= shape[i + 1]

return flat_index


cdef vector[vector[intp_t]] cartesian_cython(
vector[vector[intp_t]] sequences
) noexcept nogil:
cdef vector[vector[intp_t]] results = vector[vector[intp_t]](1)
cdef vector[vector[intp_t]] next_results
for new_values in sequences:
for result in results:
for value in new_values:
result_copy = result
result_copy.push_back(value)
next_results.push_back(result_copy)
results = next_results
next_results.clear()
return results


cpdef cartesian_python(vector[vector[intp_t]]& sequences):
return cartesian_cython(sequences)
37 changes: 37 additions & 0 deletions treeple/tree/manifold/_kernel_splitter.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np

from libcpp.vector cimport vector

from ..._lib.sklearn.tree._splitter cimport SplitRecord
from ..._lib.sklearn.utils._typedefs cimport (
float32_t,
float64_t,
int32_t,
intp_t,
uint8_t,
uint32_t,
)
from .._oblique_splitter cimport BestObliqueSplitter, ObliqueSplitRecord
from ._morf_splitter cimport PatchSplitter


cdef class UserKernelSplitter(PatchSplitter):
"""A class to hold user-specified kernels."""
# cdef vector[float32_t[:, ::1]] kernel_dictionary # A list of C-contiguous 2D kernels
cdef vector[float32_t*] kernel_dictionary # A list of C-contiguous 2D kernels
cdef vector[intp_t*] kernel_dims # A list of arrays storing the dimensions of each kernel in `kernel_dictionary`


cdef class GaussianKernelSplitter(PatchSplitter):
"""A class to hold Gaussian kernels.
Overrides the weights that are generated to be sampled from a Gaussian distribution.
See: https://www.tutorialspoint.com/gaussian-filter-generation-in-cplusplus
See: https://gist.github.com/thomasaarholt/267ec4fff40ca9dff1106490ea3b7567
"""

cdef void sample_proj_mat(
self,
vector[vector[float32_t]]& proj_mat_weights,
vector[vector[intp_t]]& proj_mat_indices
) noexcept nogil
Loading

0 comments on commit a201ba1

Please sign in to comment.