Skip to content

Commit

Permalink
Parallelize by chunk instead of querying twice.
Browse files Browse the repository at this point in the history
  • Loading branch information
Huite committed Oct 13, 2024
1 parent 374a7a4 commit 806743e
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 117 deletions.
12 changes: 9 additions & 3 deletions numba_celltree/celltree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import numba as nb
import numpy as np

from numba_celltree.algorithms import (
Expand Down Expand Up @@ -181,7 +182,8 @@ def locate_boxes(self, bbox_coords: FloatArray) -> Tuple[IntArray, IntArray]:
Indices of the face.
"""
bbox_coords = cast_bboxes(bbox_coords)
return locate_boxes(bbox_coords, self.celltree_data)
n_chunks = nb.get_num_threads()
return locate_boxes(bbox_coords, self.celltree_data, n_chunks)

def intersect_boxes(
self, bbox_coords: FloatArray
Expand All @@ -205,7 +207,8 @@ def intersect_boxes(
Area of intersection between the two intersecting faces.
"""
bbox_coords = cast_bboxes(bbox_coords)
i, j = locate_boxes(bbox_coords, self.celltree_data)
n_chunks = nb.get_num_threads()
i, j = locate_boxes(bbox_coords, self.celltree_data, n_chunks)
area = box_area_of_intersection(
bbox_coords=bbox_coords,
vertices=self.vertices,
Expand Down Expand Up @@ -246,7 +249,10 @@ def locate_faces(
"""
counter_clockwise(vertices, faces)
bbox_coords = build_bboxes(faces, vertices)
shortlist_i, shortlist_j = locate_boxes(bbox_coords, self.celltree_data)
n_chunks = nb.get_num_threads()
shortlist_i, shortlist_j = locate_boxes(
bbox_coords, self.celltree_data, n_chunks
)
intersects = polygons_intersect(
vertices_a=vertices,
vertices_b=self.vertices,
Expand Down
2 changes: 1 addition & 1 deletion numba_celltree/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CellTreeData(NamedTuple):
# (int(math.ceil(math.log(MAX_N_FACE, 2))) + 1)
# This is only true for relatively balanced trees. MAX_N_FACE = int(2e9)
# results in required stack of 32.
INITIAL_TREE_DEPTH = 32
INITIAL_STACK_LENGTH = 32
# Floating point slack
TOLERANCE_ON_EDGE = 1e-9

Expand Down
45 changes: 14 additions & 31 deletions numba_celltree/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
NodeDType,
)
from numba_celltree.geometry_utils import build_bboxes
from numba_celltree.utils import allocate_stack, pop, push
from numba_celltree.utils import (
allocate_double_stack,
pop_both,
push_both,
)


@nb.njit(inline="always")
Expand Down Expand Up @@ -227,20 +231,6 @@ def pessimistic_n_nodes(n_polys: int):
return n_nodes + 1


@nb.njit(inline="always")
def push_both(root_stack, dim_stack, root, dim, size):
    # Push ``root`` and ``dim`` onto their respective stacks in lockstep so
    # that the two stacks always share a single logical size.
    # NOTE(review): ``push`` appears to return a (possibly re-allocated)
    # stack and the new size -- confirm against utils.push; only the size
    # returned for the root stack is propagated here.
    root_stack, size_root = push(root_stack, root, size)
    dim_stack, _ = push(dim_stack, dim, size)
    return root_stack, dim_stack, size_root


@nb.njit(inline="always")
def pop_both(root_stack, dim_stack, size):
    # Pop a (root, dim) pair in lockstep; both stacks share one size counter,
    # so the size returned by the first ``pop`` is the one propagated.
    root, size_root = pop(root_stack, size)
    dim, _ = pop(dim_stack, size)
    return root, dim, size_root


@nb.njit(cache=True)
def build(
nodes: NodeArray,
Expand All @@ -251,15 +241,14 @@ def build(
cells_per_leaf: int,
):
# Cannot compile ahead of time with Numba and recursion
# Just use a stack based approach instead
root_stack = allocate_stack()
dim_stack = allocate_stack()
root_stack[0] = 0
dim_stack[0] = 0
# Just use a stack based approach instead; store root and dim values.
stack = allocate_double_stack()
stack[0, 0] = 0
stack[0, 1] = 0
size = 1

while size > 0:
root_index, dim, size = pop_both(root_stack, dim_stack, size)
root_index, dim, size = pop_both(stack, size)

dim_flag = dim
if dim < 0:
Expand Down Expand Up @@ -303,7 +292,7 @@ def build(
0, # size
)
)
# NOTA BENE: do not change the default size (0) given to the bucket here
# NOTE: do not change the default size (0) given to the bucket here
# it is used to detect empty buckets later on.

# Now that the buckets are setup, sort them
Expand Down Expand Up @@ -361,9 +350,7 @@ def build(
if dim_flag >= 0:
dim_flag = (not dim) - 2
nodes[root_index]["dim"] = not root.dim
root_stack, dim_stack, size = push_both(
root_stack, dim_stack, root_index, dim_flag, size
)
stack, size = push_both(stack, root_index, dim_flag, size)
else: # Already split once, convert to leaf.
nodes[root_index]["Lmax"] = -1
nodes[root_index]["Rmin"] = -1
Expand All @@ -389,12 +376,8 @@ def build(
node_index = push_node(nodes, left_child, node_index)
node_index = push_node(nodes, right_child, node_index)

root_stack, dim_stack, size = push_both(
root_stack, dim_stack, child_ind + 1, right_child.dim, size
)
root_stack, dim_stack, size = push_both(
root_stack, dim_stack, child_ind, left_child.dim, size
)
stack, size = push_both(stack, child_ind + 1, right_child.dim, size)
stack, size = push_both(stack, child_ind, left_child.dim, size)

return node_index

Expand Down
180 changes: 110 additions & 70 deletions numba_celltree/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

import numba as nb
import numpy as np

Expand Down Expand Up @@ -25,7 +27,32 @@
point_in_polygon_or_on_edge,
to_vector,
)
from numba_celltree.utils import allocate_polygon, allocate_stack, pop, push
from numba_celltree.utils import (
allocate_polygon,
allocate_stack,
allocate_triple_stack,
grow,
pop,
pop_triple,
push,
push_triple,
)


@nb.njit(inline="always")
def concatenate_indices(
    indices: IntArray, counts: IntArray
) -> Tuple[IntArray, IntArray]:
    # Flatten the per-chunk (i, j) pair arrays into two flat index arrays.
    # Only the first ``counts[k]`` rows of each chunk array hold valid data;
    # the remainder is unused capacity.
    n_total = 0
    for count in counts:
        n_total += count
    out_i = np.empty(n_total, dtype=IntDType)
    out_j = np.empty(n_total, dtype=IntDType)
    offset = 0
    for chunk_index in range(len(counts)):
        n_valid = counts[chunk_index]
        chunk = indices[chunk_index]
        out_i[offset : offset + n_valid] = chunk[:n_valid, 0]
        out_j[offset : offset + n_valid] = chunk[:n_valid, 1]
        offset += n_valid
    return out_i, out_j


# Inlining saves about 15% runtime
Expand Down Expand Up @@ -90,14 +117,33 @@ def locate_points(


@nb.njit(inline="always")
def locate_box(box: Box, tree: CellTreeData, indices: IntArray, store_indices: bool):
def locate_box(
box: Box, tree: CellTreeData, indices: IntArray, indices_size: int, index: int
) -> Tuple[int, int]:
"""
Search the tree for a single axis-aligned box.
Parameters
----------
box: Box named tuple
tree: CellTreeData
indices: IntArray
        Array for results. Contains ``index`` of the box we're searching for
        in the first column, and the index of the box in the celltree (if
        any) in the second.
indices_size: int
Current number of filled in values in ``indices``.
index: int
Current index of the box we're searching.
"""
tree_bbox = as_box(tree.bbox)
if not boxes_intersect(box, tree_bbox):
return 0
return 0, indices_size
stack = allocate_stack()
stack[0] = 0
size = 1
count = 0
length = len(indices)

while size > 0:
node_index, size = pop(stack, size)
Expand All @@ -110,8 +156,15 @@ def locate_box(box: Box, tree: CellTreeData, indices: IntArray, store_indices: b
# As a named tuple: saves about 15% runtime
leaf_box = as_box(tree.bb_coords[bbox_index])
if boxes_intersect(box, leaf_box):
if store_indices:
indices[count] = bbox_index
# Exit if we need to re-allocate the array. Exiting instead
# of drawing the re-allocation logic in here makes a
# significant runtime difference; seems like numba can
# optimize this form better.
if indices_size >= length:
return -1, indices_size
indices[indices_size, 0] = index
indices[indices_size, 1] = bbox_index
indices_size += 1
count += 1
else:
dim = 1 if node["dim"] else 0
Expand All @@ -130,51 +183,52 @@ def locate_box(box: Box, tree: CellTreeData, indices: IntArray, store_indices: b
elif right:
stack, size = push(stack, right_child, size)

return count
return count, indices_size


@nb.njit(parallel=PARALLEL, cache=True)
def locate_boxes(
@nb.njit(cache=True)
def locate_boxes_helper(
box_coords: FloatArray,
tree: CellTreeData,
):
    # Numba does not support a concurrent list or bag like structure:
# https://github.com/numba/numba/issues/5878
# (Standard lists are not thread safe.)
# To support parallel execution, we're stuck with numpy arrays therefore.
# Since we don't know the number of contained bounding boxes, we traverse
# the tree twice: first to count, then allocate, then another time to
# actually store the indices.
# The cost of traversing twice is roughly a factor two. Since many
# computers can parallellize over more than two threads, counting first --
# which enables parallelization -- should still result in a net speed up.
n_box = box_coords.shape[0]
counts = np.empty(n_box + 1, dtype=IntDType)
dummy = np.empty((0,), dtype=IntDType)
counts[0] = 0
# First run a count so we can allocate afterwards
for i in nb.prange(n_box): # pylint: disable=not-an-iterable
box = as_box(box_coords[i])
counts[i + 1] = locate_box(box, tree, dummy, False)

# Run a cumulative sum
total = 0
for i in range(1, n_box + 1):
total += counts[i]
counts[i] = total

# Now allocate appropriately
ii = np.empty(total, dtype=IntDType)
jj = np.empty(total, dtype=IntDType)
for i in nb.prange(n_box): # pylint: disable=not-an-iterable
start = counts[i]
end = counts[i + 1]
ii[start:end] = i
indices = jj[start:end]
box = as_box(box_coords[i])
locate_box(box, tree, indices, True)

return ii, jj
offset: int,
) -> IntArray:
n_box = len(box_coords)
# Ensure the initial indices array isn't too small.
indices = np.empty((max(n_box, 256), 2), dtype=IntDType)
total_count = 0
indices_size = 0
for box_index in range(n_box):
box = as_box(box_coords[box_index])
# Re-allocating here is significantly faster than re-allocating inside
# of ``locate_box``; presumably because that function is kept simpler and
# numba can optimize better. Unfortunately, that means we have to keep
# trying until we succeed here; in most cases, success is immediate as
# the indices array will have enough capacity.
while True:
count, indices_size = locate_box(
box, tree, indices, indices_size, box_index + offset
)
if count != -1:
break
# Not enough capacity: grow capacity, discard partial work, retry.
indices_size = total_count
indices = grow(indices)
total_count += count
return indices, total_count


@nb.njit(cache=True, parallel=PARALLEL)
def locate_boxes(box_coords: FloatArray, tree: CellTreeData, n_chunks: int):
    """
    Find the tree faces intersecting each box, parallelized over chunks.

    Splits ``box_coords`` into ``n_chunks`` pieces, searches each chunk in
    parallel, then concatenates the per-chunk results. The ``offset`` passed
    to each worker makes the returned box indices refer to positions in the
    original (unsplit) ``box_coords`` array.
    """
    chunks = np.array_split(box_coords, n_chunks)
    # Cumulative start index of each chunk within box_coords; offsets[0] is 0.
    offsets = np.zeros(n_chunks, dtype=IntDType)
    for i, chunk in enumerate(chunks[:-1]):
        offsets[i + 1] = offsets[i] + len(chunk)
    # Setup (dummy) typed list for numba to store parallel results. Each
    # prange iteration writes only to its own slot, so this is thread-safe.
    indices = [np.empty((0, 2), dtype=IntDType) for _ in range(n_chunks)]
    counts = np.empty(n_chunks, dtype=IntDType)
    for i in nb.prange(n_chunks):
        indices[i], counts[i] = locate_boxes_helper(chunks[i], tree, offsets[i])
    return concatenate_indices(indices, counts)


# Inlining this function drives compilation time through the roof. It's
Expand Down Expand Up @@ -340,27 +394,18 @@ def collect_node_bounds(tree: CellTreeData) -> FloatArray:
node_bounds[0, 2] = tree.bbox[2]
node_bounds[0, 3] = tree.bbox[3]

stack = allocate_stack()
parent_stack = allocate_stack()
side_stack = allocate_stack()

# Right child
stack[0] = 2
parent_stack[0] = 0
side_stack[0] = 0
# Left child
stack[1] = 1
parent_stack[1] = 0
side_stack[1] = 1
# Stack size starts at two.
# Stack contains: node_index, parent_index, side (right/left)
ROOT = 0
RIGHT = 0
LEFT = 1
stack = allocate_triple_stack()
stack[0, :] = (2, ROOT, RIGHT) # Right child
stack[1, :] = (1, ROOT, LEFT) # Left child
size = 2

while size > 0:
# Collect from stacks
# Sizes are synchronized.
parent_index, _ = pop(parent_stack, size)
side, _ = pop(side_stack, size)
node_index, size = pop(stack, size)
node_index, parent_index, side, size = pop_triple(stack, size)

parent = tree.nodes[parent_index]
bbox = node_bounds[parent_index]
Expand Down Expand Up @@ -388,14 +433,9 @@ def collect_node_bounds(tree: CellTreeData) -> FloatArray:
right_child = left_child + 1

# Right child
push(parent_stack, node_index, size)
push(side_stack, 0, size)
stack, size = push(stack, right_child, size)

stack, size = push_triple(stack, right_child, node_index, RIGHT, size)
# Left child
push(parent_stack, node_index, size)
push(side_stack, 1, size)
stack, size = push(stack, left_child, size)
stack, size = push_triple(stack, left_child, node_index, LEFT, size)

return node_bounds

Expand Down
Loading

0 comments on commit 806743e

Please sign in to comment.