Skip to content

Commit

Permalink
Added: Ball query for batch vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
KanishkNavale committed Feb 6, 2024
1 parent 55d524b commit e21399e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Bound to a long term development. Below is the list of present contents:

- Pointcloud:
- Furthest point sampling
- Ball query

- Robotics:
- Stable inverse jacobian
Expand Down
10 changes: 6 additions & 4 deletions heimdall/pointcloud/sampling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import torch

from heimdall.utils import convert_numpy_to_tensor
Expand Down Expand Up @@ -70,10 +72,10 @@ def furthest_point_sampling(

@convert_numpy_to_tensor
def ball_query(
vector: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs
) -> torch.Tensor:
distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=0), dim=-1, ord=2)
vectors: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs
) -> List[torch.Tensor]:
distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=1), dim=-1, ord=2)

sampling_mask = (distances <= radius) * (distances > 0.0)

return pointcloud[sampling_mask]
return [pointcloud[mask] for mask in sampling_mask]

0 comments on commit e21399e

Please sign in to comment.