This repository contains an implementation of a compact-hashing-based neighborhood search for 1D, 2D and 3D data in PyTorch using a C++/CUDA backend. This code is designed for large-scale problems, e.g., point clouds with
Requirements:
PyTorch >= 2.0
The module is either built just-in-time (JIT; this is what you get when you install it directly via pip) or available pre-built for a variety of systems via conda or our website. Note that on macOS, an external clang compiler installed via Homebrew is required for OpenMP support.
Anaconda:
```bash
conda install pytorch pyfluids::torch-compact-radius torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
```
pip:
```bash
pip install torchCompactRadius -f https://fluids.dev/torchCompactRadius/wheels/torch-2.5.0+{cuTag}/
```
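Note: if you are using Google Colab (or a similar notebook environment), you can run
```python
import torch
# Assumption: {version} is meant to be the installed torch version string (e.g. 2.5.0+cu124);
# IPython substitutes Python variables written as {name} into shell commands
version = torch.__version__
!pip install torchCompactRadius -f https://fluids.dev/torchCompactRadius/wheels/torch-{version}/
```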
Or install the JIT-compiled version available on PyPI:
```bash
pip install torchCompactRadius
```
If you install the latter, it makes sense to limit the CUDA architectures the code is compiled for before importing torchCompactRadius:
```python
import os
import torch

# Restrict JIT compilation to the compute capability of the first visible GPU
props = torch.cuda.get_device_properties(0)
os.environ['TORCH_CUDA_ARCH_LIST'] = f'{props.major}.{props.minor}'

import torchCompactRadius
```
Note that the usage described below has changed from previous versions.
This package provides two primary functions, radius and radiusSearch. radius is designed as a drop-in replacement for torch_cluster's radius function, whereas radiusSearch is the preferred interface. Important: radius and radiusSearch return index pairs in flipped order!
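Because the two functions return their pairs in flipped order, converting a radiusSearch COO result into the radius ordering amounts to swapping the two index arrays. The sketch below assumes, hypothetically, that radius returns a stacked [2, E] index tensor (as torch_cluster's radius does) and that the row/col fields of a SparseCOO hold the two halves of each pair; cooToRadiusOrder is an illustrative helper, not part of the package.

```python
from typing import NamedTuple

import torch


class SparseCOO(NamedTuple):
    # Local copy of the NamedTuple described in this README
    row: torch.Tensor
    col: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor


def cooToRadiusOrder(adjacency: SparseCOO) -> torch.Tensor:
    # Hypothetical helper: stack (col, row) to obtain the flipped,
    # radius-style ordering as a single [2, E] index tensor
    return torch.stack([adjacency.col, adjacency.row], dim=0)


adjacency = SparseCOO(
    row=torch.tensor([0, 0, 1]),
    col=torch.tensor([0, 2, 1]),
    numRows=torch.tensor(2),
    numCols=torch.tensor(3),
)
edgeIndex = cooToRadiusOrder(adjacency)  # edgeIndex[0] == col, edgeIndex[1] == row
```

Swapping is its own inverse, so the same stack with the arguments exchanged converts back.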
To call radiusSearch, we use a set of NamedTuples that make the calling convention less error-prone:
```python
from typing import NamedTuple, Optional, Union

import torch


class DomainDescription(NamedTuple):
    min: torch.Tensor
    max: torch.Tensor
    periodicity: Union[bool, torch.Tensor]
    dim: int


class PointCloud(NamedTuple):
    positions: torch.Tensor
    supports: Optional[torch.Tensor] = None


class SparseCOO(NamedTuple):
    row: torch.Tensor
    col: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor


class SparseCSR(NamedTuple):
    indices: torch.Tensor
    indptr: torch.Tensor
    rowEntries: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor
```
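To illustrate how the two sparse formats relate, the following sketch converts a SparseCOO into a SparseCSR using plain PyTorch. The field semantics are assumptions inferred from the names: indptr holds numRows + 1 row offsets, rowEntries the number of entries per row, and indices the column indices grouped by row. cooToCsr is a hypothetical helper, not part of the package.

```python
from typing import NamedTuple

import torch


class SparseCOO(NamedTuple):
    row: torch.Tensor
    col: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor


class SparseCSR(NamedTuple):
    indices: torch.Tensor
    indptr: torch.Tensor
    rowEntries: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor


def cooToCsr(coo: SparseCOO) -> SparseCSR:
    numRows = int(coo.numRows)
    # Group entries by row; a stable sort keeps the column order within each row
    order = torch.argsort(coo.row, stable=True)
    indices = coo.col[order]
    # Entries per row, then an exclusive prefix sum gives the row offsets
    rowEntries = torch.bincount(coo.row, minlength=numRows)
    indptr = torch.zeros(numRows + 1, dtype=torch.long, device=coo.row.device)
    indptr[1:] = torch.cumsum(rowEntries, dim=0)
    return SparseCSR(indices, indptr, rowEntries, coo.numRows, coo.numCols)


coo = SparseCOO(
    row=torch.tensor([1, 0, 1, 2]),
    col=torch.tensor([3, 0, 1, 2]),
    numRows=torch.tensor(3),
    numCols=torch.tensor(4),
)
csr = cooToCsr(coo)
# The columns of row i are csr.indices[csr.indptr[i]:csr.indptr[i + 1]]
```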
Based on these, we can then construct an input set:
```python
import torch
# Import path assumed; DomainDescription, PointCloud and volumeToSupport are provided by the package
from torchCompactRadius import DomainDescription, PointCloud, volumeToSupport

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
periodic = False  # set to True for a periodic domain

dim = 2
targetNumNeighbors = 32
nx = 32
minDomain = torch.tensor([-1] * dim, dtype = torch.float32, device = device)
maxDomain = torch.tensor([ 1] * dim, dtype = torch.float32, device = device)
periodicity = torch.tensor([periodic] * dim, device = device, dtype = torch.bool)

extent = maxDomain - minDomain
shortExtent = torch.min(extent, dim = 0)[0].item()
dx = (shortExtent / nx)
# Support radius chosen so that each point has roughly targetNumNeighbors neighbors
h = volumeToSupport(dx**dim, targetNumNeighbors, dim)

# Regular grid of cell-centered sample points
positions = []
for d in range(dim):
    positions.append(torch.linspace(minDomain[d] + dx / 2, maxDomain[d] - dx / 2, int((extent[d] - dx) / dx) + 1, device = device))
grid = torch.meshgrid(*positions, indexing = 'xy')
positions = torch.stack(grid, dim = -1).reshape(-1, dim).to(device)
supports = torch.ones(positions.shape[0], device = device) * h

domainDescription = DomainDescription(minDomain, maxDomain, periodicity, dim)
pointCloudX = PointCloud(positions, supports)
```
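The example relies on volumeToSupport to turn a particle volume into a cut-off radius. A plausible reading, and only an assumption about what the package computes, is that h is chosen so the d-dimensional ball of radius h holds the combined volume of targetNumNeighbors particles: 2h in 1D, πh² in 2D and (4/3)πh³ in 3D. volumeToSupportSketch below is a hypothetical re-derivation, not the library function.

```python
import math


def volumeToSupportSketch(volume: float, targetNumNeighbors: int, dim: int) -> float:
    # Hypothetical: pick h such that the d-ball of radius h has the same
    # volume as targetNumNeighbors particles of the given volume
    total = volume * targetNumNeighbors
    if dim == 1:
        return total / 2.0  # ball "volume" is the length 2h
    if dim == 2:
        return math.sqrt(total / math.pi)  # area pi * h^2
    if dim == 3:
        return (3.0 * total / (4.0 * math.pi)) ** (1.0 / 3.0)  # volume 4/3 * pi * h^3
    raise ValueError(f'unsupported dimension: {dim}')


# Same setup as above: domain [-1, 1]^2 with nx = 32 and 32 target neighbors
dx = 2.0 / 32
h = volumeToSupportSketch(dx**2, 32, 2)  # roughly 0.2, i.e. a bit over 3 * dx
```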
We can then call the radiusSearch method to compute the neighborhood in COO format:
```python
adjacency = radiusSearch(pointCloudX, domain = domainDescription)
```
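A COO adjacency is typically consumed by counting or reducing over the pairs. The sketch below works on a hand-built SparseCOO and assumes that row indexes the query points and col the reference points; given the flipped-order note above, check this against the actual output. neighborCounts and neighborSum are hypothetical helpers built on torch.bincount and Tensor.index_add_.

```python
from typing import NamedTuple

import torch


class SparseCOO(NamedTuple):
    row: torch.Tensor
    col: torch.Tensor
    numRows: torch.Tensor
    numCols: torch.Tensor


def neighborCounts(adjacency: SparseCOO) -> torch.Tensor:
    # Number of COO entries per row; rows without entries get 0
    return torch.bincount(adjacency.row, minlength=int(adjacency.numRows))


def neighborSum(adjacency: SparseCOO, values: torch.Tensor) -> torch.Tensor:
    # For every row i, sum values[j] over all pairs (i, j) in the adjacency
    out = torch.zeros(int(adjacency.numRows), *values.shape[1:],
                      dtype=values.dtype, device=values.device)
    out.index_add_(0, adjacency.row, values[adjacency.col])
    return out


adjacency = SparseCOO(
    row=torch.tensor([0, 0, 1, 2, 2, 2]),
    col=torch.tensor([0, 1, 1, 0, 1, 2]),
    numRows=torch.tensor(3),
    numCols=torch.tensor(3),
)
values = torch.tensor([1.0, 10.0, 100.0])
counts = neighborCounts(adjacency)
sums = neighborSum(adjacency, values)
```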
The radiusSearch method has some further options:
```python
def radiusSearch(
    queryPointCloud: PointCloud,
    referencePointCloud: Optional[PointCloud],
    supportOverride : Optional[float] = None,
    mode : str = 'gather',
    domain : Optional[DomainDescription] = None,
    hashMapLength = 4096,
    algorithm: str = 'naive',
    verbose: bool = False,
    format: str = 'coo',
    returnStructure : bool = False
)
```
- queryPointCloud contains the set of points that are related to the other set.
- referencePositions contains the reference set of points, i.e., the points for which relations are queried.
- support determines the cut-off radius for the radius search. This value is either a scalar float, i.e., every point has an identical cut-off radius, or a single Tensor of size $n$ that contains a different cut-off radius for every point in queryPositions.
- mode determines the method used to compute the cut-off radius of point-to-point interactions. Options are (a) gather, which uses only the cut-off radius of the queryPositions, (b) scatter, which uses only the cut-off radius of the referencePositions, and (c) symmetric, which uses the mean cut-off radius.
- domainMin and domainMax are required for periodic neighborhood searches to define the coordinates at which the positions wrap around.
- periodicity indicates whether a periodic neighborhood search is performed, either as a bool (applied to all dimensions) or a list of bools (one per dimension).
- hashMapLength determines the internal length of the hash map used in the compact data structure and should be close to $n_x$.
- verbose prints additional logging information to the console.
- returnStructure decides whether the compact algorithm returns its data structure for reuse in later searches.
- format decides whether the adjacency is returned in COO or CSR format.
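The three mode options differ only in which cut-off radius decides whether a candidate pair is kept. The sketch below applies the rules listed above to a set of candidate pairs. It assumes that row indexes query points and col reference points, and that the comparison is strict (distance < cut-off). pairCutoff and filterPairs are hypothetical names, not package functions.

```python
import torch


def pairCutoff(hQuery, hReference, row, col, mode='gather'):
    # Effective cut-off radius for every candidate pair (row[k], col[k])
    if mode == 'gather':
        return hQuery[row]  # only the query point's radius
    if mode == 'scatter':
        return hReference[col]  # only the reference point's radius
    if mode == 'symmetric':
        return 0.5 * (hQuery[row] + hReference[col])  # mean of both radii
    raise ValueError(f'unknown mode: {mode}')


def filterPairs(xQuery, xReference, hQuery, hReference, row, col, mode='gather'):
    # Keep the candidate pairs whose distance lies within the mode's cut-off
    dist = torch.linalg.norm(xQuery[row] - xReference[col], dim=-1)
    keep = dist < pairCutoff(hQuery, hReference, row, col, mode)
    return row[keep], col[keep]


# Two query points and one reference point in 1D, both pairs at distance 0.5
xq = torch.tensor([[0.0], [1.0]])
xr = torch.tensor([[0.5]])
hq = torch.tensor([0.7, 0.3])
hr = torch.tensor([0.55])
row = torch.tensor([0, 1])
col = torch.tensor([0, 0])
```

With these values, gather keeps only query 0, scatter keeps both pairs, and symmetric (cut-offs 0.625 and 0.425) again keeps only query 0, so the choice of mode directly changes the resulting adjacency.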
For the algorithm parameter, the following four options exist:
- naive: This algorithm computes a dense distance matrix of size $n_x \times n_y \times d$ and performs the adjacency computations on this dense representation. This requires significant amounts of memory but is very straightforward and potentially differentiable. Complexity: $\mathcal{O}\left(n^2\right)$.
- cluster: This is a wrapper around torch_cluster's radius search and is only available if that package is installed. Note that this algorithm supports neither periodic neighbor searches nor non-uniform cut-off radii. Complexity: $\mathcal{O}\left(n^2\right)$. This algorithm is also limited to a fixed maximum number of neighbors ($256$).
- small: This algorithm is similar to cluster in its implementation and computes an all-against-all distance on the fly, i.e., it does not require large intermediate storage. It first computes the number of neighbors per particle and then allocates the corresponding memory. Accordingly, this approach is slower than cluster but more versatile. Complexity: $\mathcal{O}\left(n^2\right)$.
- compact: The primary algorithm of this library. This approach uses compact hashing and a cell-based data structure to compute neighborhoods in $\mathcal{O}\left(n\log n\right)$. The idea is based on A Parallel SPH Implementation on Multi-Core CPUs, and the GPU approach is based on Multi-Level Memory Structures for Simulating and Rendering SPH. Note that this implementation is not optimized for adaptive simulations.
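To make the contrast between the dense and the cell-based approach concrete, here is a simplified pure-PyTorch sketch: naiveRadius builds the $n_x \times n_y \times d$ difference tensor, while cellRadius bins reference points into cells of size h, sorts them by a spatial hash and only inspects the $3^d$ neighboring cells of each query point. This is not the library's C++/CUDA implementation. It assumes a uniform cut-off radius, gather semantics, no periodicity and $d \le 3$; the prime-XOR hash and all function names are illustrative choices.

```python
import itertools

import torch

PRIMES = (73856093, 19349663, 83492791)  # common spatial-hash constants


def naiveRadius(x, y, h):
    # Dense n_x x n_y x d differences: O(n^2) memory, but very simple
    dist = torch.linalg.norm(x[:, None, :] - y[None, :, :], dim=-1)
    return torch.nonzero(dist < h, as_tuple=True)


def cellHash(cells, hashMapLength):
    # XOR of cell coordinates times large primes, wrapped into the table length
    key = torch.zeros(cells.shape[0], dtype=torch.long, device=cells.device)
    for d in range(cells.shape[1]):
        key ^= cells[:, d] * PRIMES[d]
    return key % hashMapLength


def cellRadius(x, y, h, hashMapLength=4096):
    origin = torch.minimum(x.min(0).values, y.min(0).values)
    # Cell size equals the cut-off radius, so neighbors lie in the 3^d adjacent cells
    cellsX = torch.floor((x - origin) / h).long()
    cellsY = torch.floor((y - origin) / h).long()
    hashY = cellHash(cellsY, hashMapLength)
    order = torch.argsort(hashY)  # O(n log n): reference points sorted by hash
    sortedHash = hashY[order]
    rows, cols = [], []
    for offset in itertools.product((-1, 0, 1), repeat=x.shape[1]):
        target = cellsX + torch.tensor(offset, device=x.device)
        hashX = cellHash(target, hashMapLength)
        start = torch.searchsorted(sortedHash, hashX)
        counts = torch.searchsorted(sortedHash, hashX, right=True) - start
        # Expand each query's bucket [start, start + count) into candidate pairs
        row = torch.repeat_interleave(torch.arange(x.shape[0], device=x.device), counts)
        groupStart = torch.repeat_interleave(torch.cumsum(counts, 0) - counts, counts)
        within = torch.arange(row.numel(), device=x.device) - groupStart
        col = order[torch.repeat_interleave(start, counts) + within]
        # Drop hash collisions (other cells in the same bucket) and distant points
        sameCell = (cellsY[col] == target[row]).all(dim=-1)
        close = torch.linalg.norm(x[row] - y[col], dim=-1) < h
        keep = sameCell & close
        rows.append(row[keep])
        cols.append(col[keep])
    return torch.cat(rows), torch.cat(cols)
```

Checking each candidate's actual cell against the searched cell also guarantees that no pair is reported twice, even when two neighboring cells collide in the hash table.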
To evaluate the performance on your system, simply run scripts/benchmark.py, which generates a Benchmark.png covering various point counts, algorithms and dimensions.
Compute performance on GPUs for small-scale problems: [figures: 3090, A5000]
CPU performance: [figure]
Overall GPU-based performance for larger-scale problems: [figure]