From 2f50dddd2bda459885b3d4d95bae3751fce51dc4 Mon Sep 17 00:00:00 2001 From: sunyinlei Date: Thu, 30 Nov 2023 22:56:25 +0800 Subject: [PATCH] add stack_ball_query opp --- docs/en/understand_mmcv/ops.md | 2 +- docs/zh_cn/understand_mmcv/ops.md | 2 +- .../csrc/pytorch/npu/stack_ball_query_npu.cpp | 23 +++++++++++++++++ tests/test_ops/test_ball_query.py | 25 ++++++++++++------- 4 files changed, 41 insertions(+), 11 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 76efe288c7..0bf6d76d71 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -6,7 +6,7 @@ We implement common ops used in detection, segmentation, etc. | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | √ | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | √ | | | +| BallQuery | | √ | √ | | √ | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | √ | | √ | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 5998d4e6b4..0d80270082 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -6,7 +6,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | √ | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | √ | | | +| BallQuery | | √ | √ | | √ | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | √ | | √ | diff --git a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp new file mode 100644 index 0000000000..cd8c3ad8c9 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp @@ -0,0 +1,23 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void stack_ball_query_forward_npu(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx) { + at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous(); + double max_radius_double = double(max_radius); + EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz, xyz_batch_cnt, + new_xyz_batch_cnt, max_radius_double, nsample, idx); +} + +void stack_ball_query_forward_impl(float max_radius, int nsample, + const Tensor new_xyz, + const Tensor new_xyz_batch_cnt, + const Tensor xyz, const Tensor xyz_batch_cnt, + Tensor idx); + +REGISTER_NPU_IMPL(stack_ball_query_forward_impl, stack_ball_query_forward_npu); diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index 25899f2e1f..53da2166cf 100644 --- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -63,9 +63,17 @@ def test_ball_query(device): assert torch.all(idx == expected_idx) -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_stack_ball_query(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_stack_ball_query(device): new_xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334], [-0.4003, 2.4666, -0.5116], @@ -75,8 +83,8 @@ def test_stack_ball_query(): [-2.0668, 6.0278, -0.4875], [0.4066, 1.4211, -0.2947], [-2.0289, 2.4952, -0.1708], - [-2.0289, 2.4952, -0.1708]]).cuda() - new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32).cuda() + [-2.0289, 2.4952, -0.1708]], device=device) + new_xyz_batch_cnt = torch.tensor([5, 5], dtype=torch.int32, device=device) xyz = torch.tensor([[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466], [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626], @@ -86,15 +94,14 @@ def test_stack_ball_query(): [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], - [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, - -1.2000]]).cuda() - xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32).cuda() + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]], device=device) + xyz_batch_cnt = torch.tensor([10, 10], dtype=torch.int32, device=device) idx = ball_query(0, 0.2, 5, xyz, new_xyz, xyz_batch_cnt, new_xyz_batch_cnt) expected_idx = torch.tensor([[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], - [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]).cuda() + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], device=device) assert torch.all(idx == expected_idx) xyz = xyz.double()