Skip to content

Commit

Permalink
add stack_ball_query op
Browse files Browse the repository at this point in the history
  • Loading branch information
Ginray committed Dec 26, 2023
1 parent 857b041 commit 2f50ddd
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ We implement common ops used in detection, segmentation, etc.
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | ||
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| ||
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ---------------------------- | --- | ---- | --- | --- | ------ |
| ActiveRotatedFilter ||| | ||
| AssignScoreWithK | || | | |
| BallQuery | ||| | |
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| ||
Expand Down
23 changes: 23 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void stack_ball_query_forward_npu(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx) {
at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous();
double max_radius_double = double(max_radius);
EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt, max_radius_double, nsample, idx);
}

// Forward declaration of the dispatcher entry point; the definition lives in
// the common (device-agnostic) ops source. The signature must match
// stack_ball_query_forward_npu exactly for the registration below.
void stack_ball_query_forward_impl(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz, const Tensor xyz_batch_cnt,
                                   Tensor idx);

// Register the NPU kernel as the backend implementation of the dispatcher,
// so stack_ball_query_forward_impl routes to it on NPU tensors.
REGISTER_NPU_IMPL(stack_ball_query_forward_impl, stack_ball_query_forward_npu);
25 changes: 16 additions & 9 deletions tests/test_ops/test_ball_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,17 @@ def test_ball_query(device):
assert torch.all(idx == expected_idx)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_stack_ball_query():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_stack_ball_query(device):
new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625],
[-2.2769, 2.7817, -0.2334],
[-0.4003, 2.4666, -0.5116],
Expand All @@ -75,8 +83,8 @@ def test_stack_ball_query():
[-2.0668, 6.0278, -0.4875],
[0.4066, 1.4211, -0.2947],
[-2.0289, 2.4952, -0.1708],
[-2.0289, 2.4952, -0.1708]]).cuda()
new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda()
[-2.0289, 2.4952, -0.1708]], device=device)
new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32, device=device)
xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
[-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
[-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
Expand All @@ -86,15 +94,14 @@ def test_stack_ball_query():
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]).cuda()
xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda()
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]], device=device)
xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32, device=device)
idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt)
expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
[2, 2, 2, 2, 2], [7, 7, 7, 7, 7],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda()
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], device=device)
assert torch.all(idx == expected_idx)

xyz = xyz.double()
Expand Down

0 comments on commit 2f50ddd

Please sign in to comment.