diff --git a/examples/02_detectron2/modeling/proposal_generator/rpn.py b/examples/02_detectron2/modeling/proposal_generator/rpn.py index ce7a0f2bc..545105936 100644 --- a/examples/02_detectron2/modeling/proposal_generator/rpn.py +++ b/examples/02_detectron2/modeling/proposal_generator/rpn.py @@ -123,7 +123,7 @@ def forward(self, features): for rois, logit in zip(pred_rois, pred_logits): rois = ops.reshape()(rois, [N, -1, 4]) if self.topk > 0 and rois.shape()[1].value() > self.topk: - score_inds = ops.topk(k=self.topk)(ops.reshape()(logit, [N, -1])) + _, score_inds = ops.topk(k=self.topk)(ops.reshape()(logit, [N, -1])) boxes_topk = ops.batch_gather()(rois, score_inds) scores_topk = ops.batch_gather()( ops.reshape()(logit, [N, -1, 1]), score_inds diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 9058b38da..66aaff6c5 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -743,17 +743,8 @@ def acc_ops_topk( if sorted is not None: logger.warning("Ignoring the value of 'sorted': %s", sorted) - result_indices = topk(k=k)(input_val) - # current AIT implementation only returns indices. to match the torch topk return types, create dummy values - # - # TODO remove the hard coded dtype below, once we know whether AIT will support fp32 (thus providing an option of - # fp16 or fp32 for values) - return ( - AITTensor( - shape=result_indices.shape(), dtype="float16", name=f"{name}_result_values" - ), - result_indices, - ) + result = topk(k=k)(input_val) + return result @ait_converter(acc_ops.tuple_construct) diff --git a/fx2ait/fx2ait/test/converters/test_ait_topk.py b/fx2ait/fx2ait/test/converters/test_ait_topk.py index 0da25aac0..be04e214b 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_topk.py +++ b/fx2ait/fx2ait/test/converters/test_ait_topk.py @@ -32,8 +32,7 @@ class TestTopkConverter(AITTestCase): def test_simple(self, input: List[int], k: int) -> None: class TestModule(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: - values, indices = torch.topk(x, k) - return indices + return torch.topk(x, k) model = TestModule().cuda() inputs = [ diff --git a/python/aitemplate/backend/common/tensor/topk_common.py b/python/aitemplate/backend/common/tensor/topk_common.py index f546795d1..e6d89d714 100644 --- a/python/aitemplate/backend/common/tensor/topk_common.py +++ b/python/aitemplate/backend/common/tensor/topk_common.py @@ -37,7 +37,7 @@ {{func_signature}} { - topk_launcher<{{dtype}}>(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output); + topk_launcher<{{dtype}}>(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output_index, output_value); } """ ) @@ -74,7 +74,8 @@ FUNC_SIGNATURE = jinja2.Template( """ -void {{func_name}}(int64_t* output, +void {{func_name}}(int64_t* output_index, + void* output_value, const void* input, const {{index_type}} elem_cnt, const {{index_type}} instance_size, @@ -94,7 +95,9 @@ FUNC_CALL_TEMPLATE = jinja2.Template( """ {{indent}}{{func_name}}( -{{indent}} {{output}}, {{input}}, +{{indent}} {{output_index}}, +{{indent}} {{output_value}}, +{{indent}} {{input}}, {{indent}} {{elem_cnt}}, {{indent}} {{instance_size}}, {{indent}} {{instance_num}}, @@ -508,7 +511,8 @@ class TmpBufferManager final { const int64_t heap_size, const int64_t init_index, const T init_value, - int64_t* out_ptr) { + int64_t* out_index_ptr, + T* out_value_ptr) { extern __shared__ char smem[]; auto* shared_entries = reinterpret_cast*>(smem); @@ -539,7 +543,8 @@ class TmpBufferManager final { // Write top_k elements in sorted array to output for (int64_t i = threadIdx.x; i < k; i += blockDim.x) { - (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex(); + (out_index_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex(); + (out_value_ptr + blockIdx.x * k)[i] = shared_entries[i].GetValue(); } } // ALIGNPTR @@ -566,7 +571,8 @@ class TmpBufferManager final { const int top_k, const void* input, void* workspace, - void* output) { + void* output_index, + void* output_value) { const int32_t k = std::min(top_k, instance_size); if (top_k < 100) { @@ -593,7 +599,8 @@ class TmpBufferManager final { heap_size, std::numeric_limits::max(), NumericTraits::min(), - (int64_t*)output); + (int64_t*)output_index, + (T*)output_value); } else { const uintptr_t ALIGNMENT = 32; @@ -621,7 +628,7 @@ class TmpBufferManager final { stream); {{prefix}}Memcpy2DAsync( - (int64_t*)output, + (int64_t*)output_index, k * sizeof(int64_t), buf_manager.SortedIndicesPtr(), instance_size * sizeof(int64_t), @@ -629,6 +636,16 @@ class TmpBufferManager final { instance_num, {{prefix}}MemcpyDefault, stream); + + {{prefix}}Memcpy2DAsync( + (T*)output_value, + k * sizeof(T), + buf_manager.SortedInPtr(), + instance_size * sizeof(T), + k * sizeof(T), + instance_num, + {{prefix}}MemcpyDefault, + stream); } } """ @@ -706,12 +723,12 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> str Rendered function call. """ - output_name = "" - assert len(func_attrs["outputs"]) == 1 + assert len(func_attrs["outputs"]) == 2 assert len(func_attrs["inputs"]) == 1 - output_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( - name=func_attrs["outputs"][0]._attrs["name"] + output_value_name = func_attrs["outputs"][0]._attrs["name"] + output_index_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render( + name=func_attrs["outputs"][1]._attrs["name"] ) input_name = func_attrs["inputs"][0]._attrs["name"] @@ -726,7 +743,8 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent=" ") -> return FUNC_CALL_TEMPLATE.render( func_name=func_attrs["name"], - output=output_name, + output_index=output_index_name, + output_value=output_value_name, input=input_name, elem_cnt=elem_cnt, instance_size=instance_size, diff --git a/python/aitemplate/compiler/ops/tensor/topk.py b/python/aitemplate/compiler/ops/tensor/topk.py index 6a3cdfbe7..fb058751a 100644 --- a/python/aitemplate/compiler/ops/tensor/topk.py +++ b/python/aitemplate/compiler/ops/tensor/topk.py @@ -59,8 +59,8 @@ class topk(Operator): .. code-block:: python X = Tensor(shape=[2, 800], name="X", is_input=True) - Y = ops.topk(k=300)(X) - y_shape = [d._attrs["values"][0] for d in Y.shape()] + value, indice = ops.topk(k=300)(X) + y_shape = [d._attrs["values"][0] for d in indice.shape()] print(y_shape) Outs: @@ -87,8 +87,10 @@ def __call__(self, x: Tensor) -> Tensor: self._set_depth() output_shape = self._infer_shapes(x) self._extract_exec_path(x) - output = Tensor(output_shape, src_ops={self}, dtype="int64") - self._attrs["outputs"] = [output] + output_index = Tensor(output_shape, src_ops={self}, dtype="int64") + output_value = Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"]) + output = (output_value, output_index) + self._attrs["outputs"] = [output_value, output_index] return output def _get_op_attributes(self): diff --git a/tests/unittest/ops/test_batch_gather.py b/tests/unittest/ops/test_batch_gather.py index 4c210af1a..a62484130 100644 --- a/tests/unittest/ops/test_batch_gather.py +++ b/tests/unittest/ops/test_batch_gather.py @@ -149,7 +149,7 @@ def _test_batch_gather_topk( name="scores", is_input=True, ) - X3 = ops.topk(k=topK)(X2) + _, X3 = ops.topk(k=topK)(X2) X4 = ops.batch_gather()(X1, X3) X4._attrs["is_output"] = True X4._attrs["name"] = "output" diff --git a/tests/unittest/ops/test_nms.py b/tests/unittest/ops/test_nms.py index 430967af3..e6cdd6b4b 100644 --- a/tests/unittest/ops/test_nms.py +++ b/tests/unittest/ops/test_nms.py @@ -205,7 +205,7 @@ def model(): name="scores", is_input=True, ) - score_inds = ops.topk(k=topK)(X_scores) + _, score_inds = ops.topk(k=topK)(X_scores) bboxes = ops.batch_gather()(X_boxes, score_inds) OP = ops.batched_nms(iou_threshold=iou, keep_n=N) if copy_op: diff --git a/tests/unittest/ops/test_topk.py b/tests/unittest/ops/test_topk.py index 3a3353d02..cb62e7e6b 100644 --- a/tests/unittest/ops/test_topk.py +++ b/tests/unittest/ops/test_topk.py @@ -56,23 +56,37 @@ def _test_topk( name="X", is_input=True, ) + X5 = Tensor( + shape=shape, + dtype=dtype, + name="Y", + is_input=True, + ) OP = ops.topk(k=topK) if copy_op: OP = ops.topk(**OP._get_op_attributes()) - X4 = OP(X1) + X4, X5 = OP(X1) + X4._attrs["is_output"] = True X4._attrs["is_output"] = True X4._attrs["name"] = "output" + X5._attrs["is_output"] = True + X5._attrs["is_output"] = True + X5._attrs["name"] = "output2" target = detect_target() - module = compile_model(X4, target, "./tmp", f"{test_name}_{self.test_count}") + module = compile_model( + (X4, X5), target, "./tmp", f"{test_name}_{self.test_count}" + ) scores = self._create_tensors(shape, dtype) (values, y_pt) = torch.topk(scores, k=topK, dim=dim) - + torch_dtype = string_to_torch_dtype(dtype) x = scores.reshape(shape).contiguous() - y = torch.empty(o_shape).cuda().to(torch.int64) - module.run_with_tensors([x], [y]) - torch.testing.assert_close(y_pt, y, atol=0, rtol=0) + y2 = torch.empty(o_shape).cuda().to(torch.int64) + y = torch.empty(o_shape).cuda().to(torch_dtype) + module.run_with_tensors([x], [y, y2]) + torch.testing.assert_close(values, y, atol=0, rtol=0) + torch.testing.assert_close(y_pt, y2, atol=0, rtol=0) self.test_count += 1 def test_topk_heap(self):