Skip to content

Commit

Permalink
some cleanup and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Dec 13, 2023
1 parent 418255d commit fa5af95
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
17 changes: 10 additions & 7 deletions cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index,
{
const value_t* x_ptr = X + (n_cols * blockIdx.x);

__shared__ int column_index_smem;
__shared__ unsigned long long int column_index_smem;

bool pass2 = adj_ja != nullptr;

Expand Down Expand Up @@ -584,7 +584,7 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index,

const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
if (dfunc(x_ptr, y_ptr, n_cols) <= eps) {
int row_pos = atomicAdd(&column_index_smem, 1);
auto row_pos = atomicAdd(&column_index_smem, 1);
if (pass2) adj_ja[row_pos] = cur_candidate_ind;
}
}
Expand All @@ -595,14 +595,14 @@ RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index,

const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind);
if (dfunc(x_ptr, y_ptr, n_cols) <= eps) {
int row_pos = atomicAdd(&column_index_smem, 1);
auto row_pos = atomicAdd(&column_index_smem, 1);
if (pass2) adj_ja[row_pos] = cur_candidate_ind;
}
}
}

__syncthreads();
if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = column_index_smem; }
if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = (value_idx)column_index_smem; }
}

template <typename value_idx = std::int64_t,
Expand Down Expand Up @@ -1119,8 +1119,11 @@ void rbc_eps_pass(raft::resources const& handle,
vd_ptr,
nullptr);

thrust::exclusive_scan(
resource::get_thrust_policy(handle), vd_ptr, vd_ptr + n_query_rows + 1, adj_ia, 0);
thrust::exclusive_scan(resource::get_thrust_policy(handle),
vd_ptr,
vd_ptr + n_query_rows + 1,
adj_ia,
(value_idx)0);

} else {
// pass 2 -> fill in adj_ja
Expand Down Expand Up @@ -1183,7 +1186,7 @@ void rbc_eps_pass(raft::resources const& handle,
}

thrust::exclusive_scan(
resource::get_thrust_policy(handle), vd_ptr, vd_ptr + n_query_rows + 1, adj_ia, 0);
resource::get_thrust_policy(handle), vd_ptr, vd_ptr + n_query_rows + 1, adj_ia, (value_idx)0);

block_rbc_kernel_eps_max_k_copy<value_idx, 32, value_int>
<<<n_query_rows, 32, 0, resource::get_cuda_stream(handle)>>>(
Expand Down
9 changes: 2 additions & 7 deletions cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,12 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
// P::AccThCols neighboring threads need to reduce
// -> we have P::Nblk/P::AccThCols individual reductions
auto cid = cidx + i * P::AccThRows;
if (cid < this->m) {
totalSum += sums[i];
atomicUpdate(cid, sums[i]);
}

/*sums[i] = batchedBlockReduce<IdxT, P::AccThCols>(sums[i], smem);
sums[i] = raft::logicalWarpReduce<P::AccThCols>(sums[i], raft::add_op());
if (lid == 0 && cid < this->m) {
atomicUpdate(cid, sums[i]);
totalSum += sums[i];
}
__syncthreads(); // for safe smem reuse*/
__syncthreads(); // for safe smem reuse
}
// update the total edge count
totalSum = raft::blockReduce<IdxT>(totalSum, smem);
Expand Down
24 changes: 12 additions & 12 deletions python/pylibraft/pylibraft/neighbors/eps_neighborhood.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ cdef class RbcIndexFloat(RbcIndex):
@auto_convert_output
def build_rbc_index(dataset, handle=None):
"""
Builds an random ball cover index from dataset.
Builds a random ball cover index from dataset using the L2-norm.
Parameters
----------
Expand All @@ -100,7 +100,7 @@ def build_rbc_index(dataset, handle=None):
Examples
--------
see 'eps_neighbors_l2_rbc'
see 'eps_neighbors_sparse'
"""
if handle is None:
Expand Down Expand Up @@ -131,9 +131,9 @@ def build_rbc_index(dataset, handle=None):

@auto_sync_handle
@auto_convert_output
def eps_neighbors_l2(dataset, queries, eps, method="brute", handle=None):
def eps_neighbors(dataset, queries, eps, method="brute", handle=None):
"""
Perform a brute-force epsilon neighborhood search.
Perform an epsilon neighborhood search using the L2-norm.
Parameters
----------
Expand All @@ -158,7 +158,7 @@ def eps_neighbors_l2(dataset, queries, eps, method="brute", handle=None):
--------
>>> import cupy as cp
>>> from pylibraft.common import DeviceResources
>>> from pylibraft.neighbors.eps_neighborhood import eps_neighbors_l2sq
>>> from pylibraft.neighbors.eps_neighborhood import eps_neighbors
>>> n_samples = 50000
>>> n_features = 50
>>> n_queries = 1000
Expand All @@ -167,7 +167,7 @@ def eps_neighbors_l2(dataset, queries, eps, method="brute", handle=None):
>>> queries = cp.random.random_sample((n_queries, n_features),
... dtype=cp.float32)
>>> eps = 0.1
>>> adj, vd = eps_neighbors_l2sq(dataset, queries, eps)
>>> adj, vd = eps_neighbors(dataset, queries, eps)
>>> adj = cp.asarray(adj)
>>> vd = cp.asarray(vd)
>>> # pylibraft functions are often asynchronous so the
Expand Down Expand Up @@ -228,9 +228,10 @@ def eps_neighbors_l2(dataset, queries, eps, method="brute", handle=None):

@auto_sync_handle
@auto_convert_output
def eps_neighbors_l2_rbc(RbcIndex rbc_index, queries, eps, handle=None):
def eps_neighbors_sparse(RbcIndex rbc_index, queries, eps, handle=None):
"""
Perform an epsilon neighborhood search with random ball cover (rbc).
Perform an epsilon neighborhood search with random ball cover (rbc)
using the L2-norm.
Parameters
----------
Expand All @@ -256,7 +257,7 @@ def eps_neighbors_l2_rbc(RbcIndex rbc_index, queries, eps, handle=None):
--------
>>> import cupy as cp
>>> from pylibraft.common import DeviceResources
>>> from pylibraft.neighbors.eps_neighborhood import eps_neighbors_l2sq_rbc
>>> from pylibraft.neighbors.eps_neighborhood import eps_neighbors_sparse
>>> from pylibraft.neighbors.eps_neighborhood import build_rbc_index
>>> n_samples = 50000
>>> n_features = 50
Expand All @@ -267,7 +268,7 @@ def eps_neighbors_l2_rbc(RbcIndex rbc_index, queries, eps, handle=None):
... dtype=cp.float32)
>>> eps = 0.1
>>> rbc_index = build_rbc_index(dataset, handle=handle)
>>> adj_ia, adj_ja, vd = eps_neighbors_l2_rbc(rbc_index, queries, eps)
>>> adj_ia, adj_ja, vd = eps_neighbors_sparse(rbc_index, queries, eps)
>>> adj_ia = cp.asarray(adj_ia)
>>> adj_ja = cp.asarray(adj_ja)
>>> vd = cp.asarray(vd)
Expand All @@ -287,7 +288,7 @@ def eps_neighbors_l2_rbc(RbcIndex rbc_index, queries, eps, handle=None):

n_queries = queries_cai.shape[0]

adj_ia = device_ndarray.empty((n_queries +1, ), dtype='int64')
adj_ia = device_ndarray.empty((n_queries + 1, ), dtype='int64')
vd = device_ndarray.empty((n_queries + 1, ), dtype='int64')
adj_ia_cai = cai_wrapper(adj_ia)
vd_cai = cai_wrapper(vd)
Expand Down Expand Up @@ -317,7 +318,6 @@ def eps_neighbors_l2_rbc(RbcIndex rbc_index, queries, eps, handle=None):

handle.sync()
n_nnz = adj_ia.copy_to_host()[n_queries]

adj_ja = device_ndarray.empty((n_nnz, ), dtype='int64')
adj_ja_cai = cai_wrapper(adj_ja)
adj_ja_vector_view = make_device_vector_view(
Expand Down

0 comments on commit fa5af95

Please sign in to comment.