Fix AIT topk converter (facebookincubator#631)
Summary:
Pull Request resolved: facebookincubator#631

AIT topk was returning only the indices tensor. Add the values tensor as well so the op matches torch.topk behavior.
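
In short, assuming the usual AITemplate frontend imports (an illustrative sketch, not part of this diff):

    from aitemplate.compiler import ops
    from aitemplate.frontend import Tensor

    X = Tensor(shape=[2, 800], name="X", is_input=True)
    # before: indices = ops.topk(k=300)(X)   # int64 indices only
    # after, mirroring torch.topk:
    values, indices = ops.topk(k=300)(X)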

Reviewed By: terrychenism

Differential Revision: D45334595

fbshipit-source-id: 9d6bd1370928d64963e19371f6b488b94e533a91
wushirong authored and facebook-github-bot committed Apr 28, 2023
1 parent 9487143 commit 1d5a942
Showing 8 changed files with 63 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/02_detectron2/modeling/proposal_generator/rpn.py
@@ -123,7 +123,7 @@ def forward(self, features):
         for rois, logit in zip(pred_rois, pred_logits):
             rois = ops.reshape()(rois, [N, -1, 4])
             if self.topk > 0 and rois.shape()[1].value() > self.topk:
-                score_inds = ops.topk(k=self.topk)(ops.reshape()(logit, [N, -1]))
+                _, score_inds = ops.topk(k=self.topk)(ops.reshape()(logit, [N, -1]))
                 boxes_topk = ops.batch_gather()(rois, score_inds)
                 scores_topk = ops.batch_gather()(
                     ops.reshape()(logit, [N, -1, 1]), score_inds
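
For intuition, a rough eager-PyTorch analogue of the AIT snippet above (shapes and the topk value are assumed; illustrative only):

    import torch

    N, R, k = 2, 1000, 300
    rois = torch.randn(N, R, 4)
    logit = torch.randn(N, R)
    _, score_inds = torch.topk(logit, k=k, dim=1)  # [N, k] int64 indices
    # batch_gather analogues:
    boxes_topk = torch.gather(rois, 1, score_inds[..., None].expand(-1, -1, 4))
    scores_topk = torch.gather(logit[..., None], 1, score_inds[..., None])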
13 changes: 2 additions & 11 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -743,17 +743,8 @@ def acc_ops_topk(
     if sorted is not None:
         logger.warning("Ignoring the value of 'sorted': %s", sorted)
 
-    result_indices = topk(k=k)(input_val)
-    # current AIT implementation only returns indices. to match the torch topk return types, create dummy values
-    #
-    # TODO remove the hard coded dtype below, once we know whether AIT will support fp32 (thus providing an option of
-    # fp16 or fp32 for values)
-    return (
-        AITTensor(
-            shape=result_indices.shape(), dtype="float16", name=f"{name}_result_values"
-        ),
-        result_indices,
-    )
+    result = topk(k=k)(input_val)
+    return result
 
 
 @ait_converter(acc_ops.tuple_construct)
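
With the op now producing both tensors, the converter can drop the hard-coded float16 dummy-values tensor and forward the (values, indices) pair as-is. The eager behavior being matched, as a quick torch reference (illustrative only):

    import torch

    x = torch.randn(2, 8)
    values, indices = torch.topk(x, k=3)
    # values are exactly the entries at the returned indices:
    assert torch.equal(values, x.gather(1, indices))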
3 changes: 1 addition & 2 deletions fx2ait/fx2ait/test/converters/test_ait_topk.py
@@ -32,8 +32,7 @@ class TestTopkConverter(AITTestCase):
     def test_simple(self, input: List[int], k: int) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                values, indices = torch.topk(x, k)
-                return indices
+                return torch.topk(x, k)
 
         model = TestModule().cuda()
         inputs = [
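
torch.topk returns a (values, indices) named tuple, so forward can hand it back wholesale and the test now exercises both outputs at once; a quick eager check (illustrative only):

    import torch

    out = torch.topk(torch.randn(4, 10), k=3)
    print(out.values.shape, out.indices.shape)  # torch.Size([4, 3]) twice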
44 changes: 31 additions & 13 deletions python/aitemplate/backend/common/tensor/topk_common.py
@@ -37,7 +37,7 @@
 {{func_signature}}
 {
-  topk_launcher<{{dtype}}>(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output);
+  topk_launcher<{{dtype}}>(stream, elem_cnt, instance_size, instance_num, top_k, input, workspace, output_index, output_value);
 }
 """
 )
@@ -74,7 +74,8 @@
 
 FUNC_SIGNATURE = jinja2.Template(
     """
-void {{func_name}}(int64_t* output,
+void {{func_name}}(int64_t* output_index,
+                   void* output_value,
                    const void* input,
                    const {{index_type}} elem_cnt,
                    const {{index_type}} instance_size,
@@ -94,7 +95,9 @@
 FUNC_CALL_TEMPLATE = jinja2.Template(
     """
 {{indent}}{{func_name}}(
-{{indent}}    {{output}}, {{input}},
+{{indent}}    {{output_index}},
+{{indent}}    {{output_value}},
+{{indent}}    {{input}},
 {{indent}}    {{elem_cnt}},
 {{indent}}    {{instance_size}},
 {{indent}}    {{instance_num}},
@@ -508,7 +511,8 @@ class TmpBufferManager final {
     const int64_t heap_size,
     const int64_t init_index,
     const T init_value,
-    int64_t* out_ptr) {
+    int64_t* out_index_ptr,
+    T* out_value_ptr) {
   extern __shared__ char smem[];
   auto* shared_entries = reinterpret_cast<Entry<T>*>(smem);
@@ -539,7 +543,8 @@ class TmpBufferManager final {
   // Write top_k elements in sorted array to output
   for (int64_t i = threadIdx.x; i < k; i += blockDim.x) {
-    (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex();
+    (out_index_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex();
+    (out_value_ptr + blockIdx.x * k)[i] = shared_entries[i].GetValue();
   }
 }
 // ALIGNPTR
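
A rough Python analogue of the heap path's write-out, with heapq standing in for the shared-memory heap (illustrative only):

    import heapq

    def topk_per_instance(row, k):
        # Keep the k largest (value, index) entries, sorted descending by value,
        # mirroring the kernel now emitting GetIndex() and GetValue() rather
        # than indices alone.
        pairs = heapq.nlargest(k, ((v, i) for i, v in enumerate(row)))
        return [v for v, _ in pairs], [i for _, i in pairs]

    values, indices = topk_per_instance([0.1, 0.9, 0.4, 0.7], k=2)
    # values == [0.9, 0.7], indices == [1, 3]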
@@ -566,7 +571,8 @@
     const int top_k,
     const void* input,
     void* workspace,
-    void* output) {
+    void* output_index,
+    void* output_value) {
   const int32_t k = std::min(top_k, instance_size);
 
   if (top_k < 100) {
@@ -593,7 +599,8 @@
         heap_size,
         std::numeric_limits<int64_t>::max(),
         NumericTraits<T>::min(),
-        (int64_t*)output);
+        (int64_t*)output_index,
+        (T*)output_value);
   } else {
     const uintptr_t ALIGNMENT = 32;
@@ -621,14 +628,24 @@
         stream);
 
     {{prefix}}Memcpy2DAsync(
-        (int64_t*)output,
+        (int64_t*)output_index,
         k * sizeof(int64_t),
         buf_manager.SortedIndicesPtr(),
         instance_size * sizeof(int64_t),
         k * sizeof(int64_t),
         instance_num,
         {{prefix}}MemcpyDefault,
         stream);
+
+    {{prefix}}Memcpy2DAsync(
+        (T*)output_value,
+        k * sizeof(T),
+        buf_manager.SortedInPtr(),
+        instance_size * sizeof(T),
+        k * sizeof(T),
+        instance_num,
+        {{prefix}}MemcpyDefault,
+        stream);
   }
 }
 """
@@ -706,12 +723,12 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent="    ") -> str:
     str
         Rendered function call.
     """
-    output_name = ""
-    assert len(func_attrs["outputs"]) == 1
+    assert len(func_attrs["outputs"]) == 2
     assert len(func_attrs["inputs"]) == 1
 
-    output_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render(
-        name=func_attrs["outputs"][0]._attrs["name"]
+    output_value_name = func_attrs["outputs"][0]._attrs["name"]
+    output_index_name = FUNC_CALL_INT64_PARAM_TEMPLATE.render(
+        name=func_attrs["outputs"][1]._attrs["name"]
     )
     input_name = func_attrs["inputs"][0]._attrs["name"]

@@ -726,7 +743,8 @@ def gen_function_call(func_attrs: Dict[str, Any], backend_spec, indent="    ") -> str:
 
     return FUNC_CALL_TEMPLATE.render(
         func_name=func_attrs["name"],
-        output=output_name,
+        output_index=output_index_name,
+        output_value=output_value_name,
         input=input_name,
         elem_cnt=elem_cnt,
         instance_size=instance_size,
10 changes: 6 additions & 4 deletions python/aitemplate/compiler/ops/tensor/topk.py
@@ -59,8 +59,8 @@ class topk(Operator):
     .. code-block:: python
 
         X = Tensor(shape=[2, 800], name="X", is_input=True)
-        Y = ops.topk(k=300)(X)
-        y_shape = [d._attrs["values"][0] for d in Y.shape()]
+        value, indice = ops.topk(k=300)(X)
+        y_shape = [d._attrs["values"][0] for d in indice.shape()]
         print(y_shape)
 
     Outs:
@@ -87,8 +87,10 @@ def __call__(self, x: Tensor) -> Tensor:
         self._set_depth()
         output_shape = self._infer_shapes(x)
         self._extract_exec_path(x)
-        output = Tensor(output_shape, src_ops={self}, dtype="int64")
-        self._attrs["outputs"] = [output]
+        output_index = Tensor(output_shape, src_ops={self}, dtype="int64")
+        output_value = Tensor(output_shape, src_ops={self}, dtype=x._attrs["dtype"])
+        output = (output_value, output_index)
+        self._attrs["outputs"] = [output_value, output_index]
         return output
 
     def _get_op_attributes(self):
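
The list order set here is the contract the backend unpacks in gen_function_call above: outputs[0] is the values tensor (input dtype), outputs[1] the int64 indices. A small illustrative check (import paths assumed, graph setup elided):

    from aitemplate.compiler import ops
    from aitemplate.frontend import Tensor

    X = Tensor(shape=[2, 800], name="X", is_input=True)
    values, indices = ops.topk(k=300)(X)
    assert values._attrs["dtype"] == X._attrs["dtype"]  # outputs[0]: values
    assert indices._attrs["dtype"] == "int64"           # outputs[1]: indices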
2 changes: 1 addition & 1 deletion tests/unittest/ops/test_batch_gather.py
@@ -149,7 +149,7 @@ def _test_batch_gather_topk(
             name="scores",
             is_input=True,
         )
-        X3 = ops.topk(k=topK)(X2)
+        _, X3 = ops.topk(k=topK)(X2)
         X4 = ops.batch_gather()(X1, X3)
         X4._attrs["is_output"] = True
         X4._attrs["name"] = "output"
2 changes: 1 addition & 1 deletion tests/unittest/ops/test_nms.py
@@ -205,7 +205,7 @@ def model():
             name="scores",
             is_input=True,
         )
-        score_inds = ops.topk(k=topK)(X_scores)
+        _, score_inds = ops.topk(k=topK)(X_scores)
         bboxes = ops.batch_gather()(X_boxes, score_inds)
         OP = ops.batched_nms(iou_threshold=iou, keep_n=N)
         if copy_op:
26 changes: 20 additions & 6 deletions tests/unittest/ops/test_topk.py
@@ -56,23 +56,37 @@ def _test_topk(
             name="X",
             is_input=True,
         )
+        X5 = Tensor(
+            shape=shape,
+            dtype=dtype,
+            name="Y",
+            is_input=True,
+        )
         OP = ops.topk(k=topK)
         if copy_op:
            OP = ops.topk(**OP._get_op_attributes())
-        X4 = OP(X1)
+        X4, X5 = OP(X1)
         X4._attrs["is_output"] = True
         X4._attrs["name"] = "output"
+        X5._attrs["is_output"] = True
+        X5._attrs["name"] = "output2"
 
         target = detect_target()
-        module = compile_model(X4, target, "./tmp", f"{test_name}_{self.test_count}")
+        module = compile_model(
+            (X4, X5), target, "./tmp", f"{test_name}_{self.test_count}"
+        )
 
         scores = self._create_tensors(shape, dtype)
         (values, y_pt) = torch.topk(scores, k=topK, dim=dim)
 
+        torch_dtype = string_to_torch_dtype(dtype)
         x = scores.reshape(shape).contiguous()
-        y = torch.empty(o_shape).cuda().to(torch.int64)
-        module.run_with_tensors([x], [y])
-        torch.testing.assert_close(y_pt, y, atol=0, rtol=0)
+        y2 = torch.empty(o_shape).cuda().to(torch.int64)
+        y = torch.empty(o_shape).cuda().to(torch_dtype)
+        module.run_with_tensors([x], [y, y2])
+        torch.testing.assert_close(values, y, atol=0, rtol=0)
+        torch.testing.assert_close(y_pt, y2, atol=0, rtol=0)
         self.test_count += 1
 
     def test_topk_heap(self):
