From 60769b77f9abe29aafabda4d5d1cd625e7c61f9f Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Mon, 13 Aug 2018 09:35:46 -0700 Subject: [PATCH] Fixed bugs for SSD sorting and multbox detection (#1578) --- topi/python/topi/cuda/nms.py | 480 ++++++++++++++++++++------ topi/python/topi/cuda/ssd/multibox.py | 225 ++++++++---- 2 files changed, 534 insertions(+), 171 deletions(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 4d4e402de5c2..361208bf1cfb 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -7,19 +7,155 @@ from topi.vision import nms -def sort_ir(data, index, output, axis, is_descend): - """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. +def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after): + """Low level IR routing subfunction 1/4 for computing segments' staring locatons. + + Parameters + ---------- + index : Buffer + Buffer of number of valid output boxes. + + sizes_out : Buffer + Output buffer of start locations of each sorting segment. + + axis_mul_before : int + The multiplication result of axis dimensions before axis. + + axis_mul_after : int + The multiplication result of axis dimensions after axis. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib = tvm.ir_builder.create() + p_index = ib.buffer_ptr(index) + dshape = sizes_out.shape + sizes = ib.buffer_ptr(sizes_out) + nthread_tx = max_threads + nthread_bx = dshape[0] // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + sizes[tid] = p_index[tid] + + # scan + with ib.if_scope(tid < 1): + with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k: + sizes[k + 1] += sizes[k] + body = ib.get() + return body + + +def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \ + axis, axis_mul_before, axis_mul_after): + """Low level IR routing subfunction 2/4 for flattening data and indices into segmented format. Parameters ---------- data: Buffer - 2D Buffer of input boxes' score with shape [batch_size, num_anchors]. + Buffer of output boxes with class and score. index : Buffer - Buffer of number of valid number of boxes. + Buffer of number of valid output boxes. - output : Buffer - Output buffer of indicies of sorted tensor. + sizes_in : Buffer + Buffer of start locations of each sorting segment. + + data_out : Buffer + Buffer of flattened segmented data. + + index_out : Buffer + Buffer of flattened segmented indices. + + axis : int + The axis used for sorting. + + axis_mul_before : int + The multiplication result of axis dimensions before axis. + + axis_mul_after : int + The multiplication result of axis dimensions after axis. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + ib = tvm.ir_builder.create() + sizes = ib.buffer_ptr(sizes_in) + p_index = ib.buffer_ptr(index) + p_data = ib.buffer_ptr(data) + data_new = ib.buffer_ptr(data_out) + index_new = ib.buffer_ptr(index_out) + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + dshape = tvm.max(sizes_in.shape[0], p_index[0]) + nthread_tx = max_threads + nthread_bx = dshape // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(axis_mul_before * axis_mul_after > 1): + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + i = tid / axis_mul_after + j = tid % axis_mul_after + current_sort_num = p_index[tid] + base_idx = i * data.shape[axis] * axis_mul_after + j + with ib.for_range(0, current_sort_num, name="k") as k: + full_idx = base_idx + k * axis_mul_after + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + index_new[start + k] = k + data_new[start + k] = p_data[full_idx] + with ib.else_scope(): + with ib.if_scope(tid == 0): + with ib.for_range(0, p_index[0], name="k") as k: + index_new[k] = k + + body = ib.get() + return body + +def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \ + axis_mul_after, axis, is_descend): + """Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting. + + Parameters + ---------- + data: Buffer + Buffer of output boxes with class and score. + + index : Buffer + Buffer of number of valid output boxes. + + new_data : Buffer + Buffer of flattened segmented data. + + new_index : Buffer + Buffer of flattened segmented indices. + + loc : Buffer + Buffer of start locations of each sorting segment. + + out_index : Buffer + Output buffer of output box indexes sorted by score in a flattened segmented format. + + axis_mul_before : int + The multiplication result of axis dimensions before axis. + + axis_mul_after : int + The multiplication result of axis dimensions after axis. axis : int The axis used for sorting. @@ -32,15 +168,197 @@ def sort_ir(data, index, output, axis, is_descend): stmt : Stmt The result IR statement. """ - max_threads = int( tvm.target.current_target(allow_none=False).max_num_threads) tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib = tvm.ir_builder.create() + dshape = loc.shape + fshape = data.shape[axis] * dshape[0] + temp_data = ib.allocate( + "float32", dshape, name="temp_data", scope="local") p_data = ib.buffer_ptr(data) p_index = ib.buffer_ptr(index) + data_new = ib.buffer_ptr(new_data) + index_new = ib.buffer_ptr(new_index) + index_out = ib.buffer_ptr(out_index) + sizes = ib.buffer_ptr(loc) + nthread_tx = max_threads + nthread_bx = fshape // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(axis_mul_before * axis_mul_after > 1): + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + # OddEvenTransposeSort + with ib.for_range(0, p_index[tid], name="k") as k: + with ib.for_range(0, p_index[tid] - 1, name="i") as i: + with ib.if_scope(i % 2 == k % 2): + with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)): + temp_data[tid] = data_new[i+start] + data_new[i+start] = data_new[i+start+1] + data_new[i+start+1] = temp_data[tid] + index_out[tid] = index_new[i+start] + index_new[i+start] = index_new[i+start+1] + index_new[i+start+1] = index_out[tid] + with ib.if_scope(tid < 1): + with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i: + index_out[i] = index_new[i] + with ib.else_scope(): + with ib.for_range(0, fshape, name="k", for_type="unroll") as k: + with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)): + with ib.if_scope(k % 2 == 0): + with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \ + == is_descend)): + data_new[tid] = p_data[tid+1] + index_out[tid] = index_new[tid+1] + with ib.else_scope(): + data_new[tid] = p_data[tid] + index_out[tid] = index_new[tid] + with ib.else_scope(): + with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \ + == is_descend)): + p_data[tid] = data_new[tid+1] + index_new[tid] = index_out[tid+1] + with ib.else_scope(): + p_data[tid] = data_new[tid] + index_new[tid] = index_out[tid] + with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)): + with ib.if_scope(k % 2 == 0): + with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)): + data_new[tid] = p_data[tid-1] + index_out[tid] = index_new[tid-1] + with ib.else_scope(): + data_new[tid] = p_data[tid] + index_out[tid] = index_new[tid] + with ib.else_scope(): + with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \ + == is_descend)): + p_data[tid] = data_new[tid-1] + index_new[tid] = index_out[tid-1] + with ib.else_scope(): + p_data[tid] = data_new[tid] + index_new[tid] = index_out[tid] + with ib.if_scope(fshape % 2 == 1): + with ib.if_scope(tid < 1): + with ib.for_range(0, fshape, name="k") as k: + index_out[tid] = index_new[tid] + body = ib.get() + return body + + +def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis): + """Low level IR routing subfunction 4/4 for writing sorted indices to output format. + + Parameters + ---------- + data: Buffer + Buffer of output boxes with class and score. + + index : Buffer + Buffer of number of valid output boxes. + + new_index : Buffer + Buffer of sorted indices in a flatten format. + + loc : Buffer + Buffer of start locations of each sorting segment. + + output : Buffer + Output buffer of output box indexes sorted by score. + + axis_mul_before : int + The multiplication result of axis dimensions before axis. + + axis_mul_after : int + The multiplication result of axis dimensions after axis. + + axis : int + The axis used for sorting. + + is_descend : bool + If the sorted data is in descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib = tvm.ir_builder.create() + dshape = tvm.max(loc.shape[0], data.shape[axis]) + p_index = ib.buffer_ptr(index) + index_new = ib.buffer_ptr(new_index) + sizes = ib.buffer_ptr(loc) p_out = ib.buffer_ptr(output) + nthread_tx = max_threads + nthread_bx = dshape // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(axis_mul_before * axis_mul_after > 1): + with ib.if_scope(tid < axis_mul_before * axis_mul_after): + i = tid / axis_mul_after + j = tid % axis_mul_after + base_idx = i * data.shape[axis] * axis_mul_after + j + with ib.for_range(0, data.shape[axis], name="k") as k: + with ib.if_scope(tid == 0): + start = 0 + with ib.else_scope(): + start = sizes[tid-1] + p_out[base_idx + k * axis_mul_after] = tvm.select( + k < p_index[tid], index_new[k+start], k) + with ib.else_scope(): + with ib.if_scope(tid < data.shape[axis]): + p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid) + + body = ib.get() + return body + + +def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend): + """Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu. + + Parameters + ---------- + data: tvm.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6]. + The last dimension should be in format of + [class_id, score, box_left, box_top, box_right, box_bottom]. + + data_buf: Buffer + 2D Buffer of input boxes' score with shape [batch_size, num_anchors]. + + index : tvm.Tensor + 1-D tensor for valid number of boxes. + + index_buf : Buffer + Buffer of number of valid number of boxes. + + output_buf : Buffer + Output buffer of indicies of sorted tensor. + + axis : int + The axis used for sorting. + + is_descend : bool + If the sorted data is in descending order. + + Returns + ------- + out : tvm.Tensor + 3-D tensor with shape [batch_size, num_anchors]. + """ + ndim = len(data.shape) assert data.dtype == "float32", "Currently only supports input dtype to be float32" assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim @@ -55,89 +373,60 @@ def sort_ir(data, index, output, axis, is_descend): elif i > axis: axis_mul_after *= data.shape[i] - dshape = 0 - for i in range(0, len(index.shape)): - dshape += index.shape[i] - dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape, - axis_mul_before*axis_mul_after) - - sizes_temp = ib.allocate( - "int32", dshape, name="sizes_temp", scope="global") - sizes = ib.allocate("int32", dshape, name="sizes", scope="global") - temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local") - temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local") - data_new = ib.allocate("float32", dshape, name="data_new", scope="global") - index_new = ib.allocate("int32", dshape, name="index_new", scope="global") - nthread_tx = max_threads - nthread_bx = dshape // max_threads + 1 - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - - with ib.if_scope(tid < axis_mul_before * axis_mul_after): - sizes[tid] = p_index[tid] - sizes_temp[tid] = p_index[tid] - - with ib.if_scope(tid < axis_mul_before * axis_mul_after): - with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \ - .astype("float32"))) + 1, name="k") as k: - with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0): - with ib.if_scope(k % 2 == 0): - sizes[tid] += sizes_temp[tid - ( - tvm.const(1, "int32") << k)] - sizes_temp[tid] = sizes[tid] - with ib.else_scope(): - sizes_temp[tid] += sizes[tid - ( - tvm.const(1, "int32") << k)] - sizes[tid] = sizes_temp[tid] - - with ib.if_scope(tid < axis_mul_before * axis_mul_after): - i = tid / axis_mul_after - j = tid % axis_mul_after - current_sort_num = p_index[tid] - base_idx = i * data.shape[axis] * axis_mul_after + j - with ib.for_range(0, current_sort_num, name="k") as k: - full_idx = base_idx + k * axis_mul_after - with ib.if_scope(tid == 0): - start = 0 - with ib.else_scope(): - start = sizes[tid-1] - index_new[start + k] = k - data_new[start + k] = p_data[full_idx] - - with ib.if_scope(tid < axis_mul_before * axis_mul_after): - with ib.if_scope(tid == 0): - start = 0 - with ib.else_scope(): - start = sizes[tid-1] - # OddEvenTransposeSort - with ib.for_range(0, p_index[tid], name="k") as k: - with ib.for_range(0, p_index[tid] - 1, name="i") as i: - with ib.if_scope(i % 2 == (k & 1)): - with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^ - is_descend) == False): - temp_data[tid] = data_new[i+start] - data_new[i+start] = data_new[i+start+1] - data_new[i+start+1] = temp_data[tid] - temp_index[tid] = index_new[i+start] - index_new[i+start] = index_new[i+start+1] - index_new[i+start+1] = temp_index[tid] - - with ib.if_scope(tid < axis_mul_before * axis_mul_after): - i = tid / axis_mul_after - j = tid % axis_mul_after - current_sort_num = p_index[tid] - base_idx = i * data.shape[axis] * axis_mul_after + j - with ib.for_range(0, data.shape[axis], name="k") as k: - with ib.if_scope(tid == 0): - start = 0 - with ib.else_scope(): - start = sizes[tid-1] - p_out[base_idx + k * axis_mul_after] = tvm.select( - k < current_sort_num, - index_new[k+start], k) - body = ib.get() - return body + dshape = axis_mul_before*axis_mul_after + fshape = data.shape[axis] * dshape + + loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8) + new_index_buf = api.decl_buffer( + fshape, index.dtype, "index_new", data_alignment=8) + out_index_buf = api.decl_buffer( + fshape, index.dtype, "index_out", data_alignment=8) + new_data_buf = api.decl_buffer( + dshape, data.dtype, "data_new", data_alignment=8) + + loc = \ + tvm.extern([(dshape,)], + [index], + lambda ins, outs: sort_pre_ir( + ins[0], outs[0], axis_mul_before, axis_mul_after), + dtype=[index.dtype], + in_buffers=index_buf, + out_buffers=[loc_buf], + tag="sorting_prepare") + + data_new, index_new = \ + tvm.extern([(dshape,), (fshape,)], + [data, index, loc], + lambda ins, outs: sort_pre_ir_data( + ins[0], ins[1], ins[2], outs[0], outs[1], axis, + axis_mul_before, axis_mul_after), + dtype=[data.dtype, index.dtype], + in_buffers=[data_buf, index_buf, loc_buf], + out_buffers=[new_data_buf, new_index_buf], + tag="sorting_data") + + index_out = \ + tvm.extern([(fshape,)], + [data, index, data_new, index_new, loc], + lambda ins, outs: sort_oet_ir( + ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], + axis_mul_before, axis_mul_after, axis, is_descend), + dtype=[index.dtype], + in_buffers=[data_buf, index_buf, + new_data_buf, new_index_buf, loc_buf], + out_buffers=[out_index_buf], + tag="sorting_oet") + out = \ + tvm.extern([data.shape], + [data, index, index_out, loc], + lambda ins, outs: sort_ir_out( + ins[0], ins[1], ins[2], ins[3], outs[0], + axis_mul_before, axis_mul_after, axis), + dtype=[index.dtype], + in_buffers=[data_buf, index_buf, out_index_buf, loc_buf], + out_buffers=output_buf, + tag="sorting_output") + return out def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): @@ -333,15 +622,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype, "sort_tensor_buf", data_alignment=8) - sort_tensor = \ - tvm.extern(score_shape, - [score_tensor, valid_count], - lambda ins, outs: sort_ir( - ins[0], ins[1], outs[0], score_axis, True), - dtype=sort_tensor_dtype, - in_buffers=[score_tensor_buf, valid_count_buf], - out_buffers=sort_tensor_buf, - name="nms_sort") + sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count, + valid_count_buf, sort_tensor_buf, score_axis, True) out = \ tvm.extern(data.shape, [data, sort_tensor, valid_count], diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index c22e7a513d7d..3c013c4d1605 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -1,4 +1,4 @@ -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args """SSD multibox operators""" from __future__ import absolute_import as _abs import math @@ -13,6 +13,7 @@ from topi.vision.ssd import multibox_transform_loc from ..nms import nms + def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): """Low level IR routing for multibox_prior operator. @@ -41,7 +42,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): stmt : Stmt The result IR statement. """ - max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) tx = tvm.thread_axis("threadIdx.x") ty = tvm.thread_axis("threadIdx.y") bx = tvm.thread_axis("blockIdx.x") @@ -76,7 +78,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): for k in range(num_sizes + num_ratios - 1): w = tvm.select(k < num_sizes, - size_ratio_concat[k] * in_height / in_width / 2.0, + size_ratio_concat[ + k] * in_height / in_width / 2.0, size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0) h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0, @@ -93,7 +96,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): @multibox_prior.register(["cuda", "gpu"]) -def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \ +def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): """Generate prior(anchor) boxes from data, sizes and ratios. @@ -124,31 +127,114 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \ """ num_sizes = len(sizes) num_ratios = len(ratios) - oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4) + oshape = ( + 1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4) out = tvm.extern(oshape, [data], lambda ins, outs: - multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets), + multibox_prior_ir( + ins[0], outs[0], sizes, ratios, steps, offsets), tag="multibox_prior") if clip: out = topi.clip(out, 0, 1) return out -def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances): - """Low level IR routing for transform location in multibox_detection operator. +def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold): + """Low level IR routing for transform location data preparation. Parameters ---------- cls_prob : Buffer Buffer of class probabilities. + valid_count : Buffer + Buffer of number of valid output boxes. + + temp_flag : Buffer + Output intermediate result buffer + + temp_id : Buffer + Output intermediate result buffer + + temp_score_out : Buffer + Output buffer + + threshold : float + Threshold to be a positive prediction. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = cls_prob.shape[0] + num_classes = cls_prob.shape[1] + num_anchors = cls_prob.shape[2] + + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) + ib = tvm.ir_builder.create() + score = ib.buffer_ptr(temp_score_out) + cls_id = ib.buffer_ptr(temp_id) + flag = ib.buffer_ptr(temp_flag) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + p_cls_prob = ib.buffer_ptr(cls_prob) + p_valid_count = ib.buffer_ptr(valid_count) + + with ib.if_scope(tid < batch_size * num_anchors): + n = tid / num_anchors # number of batches + i = tid % num_anchors # number of anchors + score[i] = -1.0 + cls_id[i] = 0 + p_valid_count[n] = 0 + with ib.for_range(0, num_classes-1, name="k") as k: + temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i] + with ib.if_scope(temp > score[i]): + cls_id[i] = k + 1 + score[i] = temp + with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)): + cls_id[i] = 0 + with ib.if_scope(cls_id[i] > 0): + flag[i] = 1 + with ib.else_scope(): + flag[i] = 0 + + with ib.if_scope(tid < batch_size): + with ib.for_range(0, num_anchors, name="k") as k: + with ib.if_scope(k > 0): + flag[tid * num_anchors + + k] += flag[tid * num_anchors + k - 1] + p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1] + + body = ib.get() + return body + + +def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ + out, clip, variances, batch_size, num_classes, num_anchors): + """Low level IR routing for transform location in multibox_detection operator. + + Parameters + ---------- loc_pred : Buffer Buffer of location regression predictions. anchor : Buffer Buffer of prior anchor boxes. - valid_count : Buffer - Buffer of number of valid output boxes. + temp_flag : Buffer + Intermediate result buffer. + + temp_id : Buffer + Intermediate result buffer. + + temp_score_in : Buffer + Input buffer which stores intermediate results. out : Buffer Output buffer. @@ -156,12 +242,18 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho clip : boolean Whether to clip out-of-boundary boxes. - threshold : float - Threshold to be a positive prediction. - variances : tuple of float Variances to be decoded from box regression output. + batch_size : int + Batch size + + num_classes : int + Number of classes + + num_anchors : int + Number of anchors + Returns ------- stmt : Stmt @@ -187,21 +279,16 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, ow = tvm.exp(pw * vw) * aw / 2.0 oh = tvm.exp(ph * vh) * ah / 2.0 return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) - - batch_size = cls_prob.shape[0] - num_classes = cls_prob.shape[1] - num_anchors = cls_prob.shape[2] + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ + tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) + max_threads = int( + tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() - temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \ - ), name="temp_score", scope="global") - score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local") - cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local") - flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global") - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + score = ib.buffer_ptr(temp_score_in) + cls_id = ib.buffer_ptr(temp_id) + flag = ib.buffer_ptr(temp_flag) tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") nthread_tx = max_threads @@ -209,42 +296,13 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - p_cls_prob = ib.buffer_ptr(cls_prob) p_loc_pred = ib.buffer_ptr(loc_pred) p_anchor = ib.buffer_ptr(anchor) - p_valid_count = ib.buffer_ptr(valid_count) p_out = ib.buffer_ptr(out) - with ib.if_scope(tid < batch_size * num_anchors * num_classes): - n = tid / (num_anchors * num_classes) - j = (tid % (num_anchors * num_classes)) / num_anchors - i = tid % num_anchors - with ib.if_scope(j > 0): - temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \ - p_cls_prob[tid] - p_valid_count[n] = 0 - with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors - i = tid % num_anchors - score[tid] = -1.0 - cls_id[tid] = 0 - with ib.for_range(0, num_classes-1, name="k") as k: - temp = temp_score[tid * (num_classes-1) + k] - cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid]) - score[tid] = tvm.make.Max(temp, score[tid]) - with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)): - cls_id[tid] = 0 - with ib.if_scope(cls_id[tid] > 0): - flag[tid] = 1 - with ib.else_scope(): - flag[tid] = 0 - with ib.if_scope(tid < batch_size): - with ib.for_range(0, num_anchors, name="k") as k: - with ib.if_scope(k > 0): - flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1] - p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1] + with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors - i = tid % num_anchors + n = tid / num_anchors # number of batches + i = tid % num_anchors # number of anchors with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(tid == 0): out_base_idx = n * num_anchors * 6 @@ -253,17 +311,17 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, p_out[out_base_idx] = cls_id[tid] - 1.0 p_out[out_base_idx + 1] = score[tid] p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ - p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4, - clip, variances[0], variances[1], - variances[2], variances[3]) + p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, + p_anchor, i*4, clip, variances[0], + variances[1], variances[2], variances[3]) body = ib.get() return body @multibox_transform_loc.register(["cuda", "gpu"]) -def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, - variances=(0.1, 0.1, 0.2, 0.2)): +def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \ + threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)): """Location transformation for multibox detection Parameters @@ -297,20 +355,42 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold= 1-D tensor with shape (batch_size,), number of valid anchor boxes. """ batch_size = cls_prob.shape[0] - num_anchors = anchor.shape[1] + num_classes = cls_prob.shape[1] + num_anchors = cls_prob.shape[2] oshape = (batch_size, num_anchors, 6) # Define data alignment for intermediate buffer valid_count_dtype = "int32" valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype, "valid_count_buf", data_alignment=4) - out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8) - valid_count, out = \ - tvm.extern([(batch_size,), oshape], - [cls_prob, loc_pred, anchor], + out_buf = api.decl_buffer( + oshape, cls_prob.dtype, "out_buf", data_alignment=8) + size = num_anchors + temp_flag_buf = api.decl_buffer( + (size,), valid_count_dtype, "flag", data_alignment=8) + temp_id_buf = api.decl_buffer( + (size,), valid_count_dtype, "cls_id", data_alignment=8) + temp_score_buf = api.decl_buffer( + (size,), cls_prob.dtype, "score", data_alignment=8) + + valid_count, temp_flag, temp_id, temp_score = \ + tvm.extern([(batch_size,), (size,), (size,), (size,)], + [cls_prob], + lambda ins, outs: transform_loc_pre( + ins[0], outs[0], outs[1], outs[2], outs[3], threshold), + dtype=[valid_count_dtype, + valid_count_dtype, valid_count_dtype, cls_prob.dtype], + out_buffers=[valid_count_buf, + temp_flag_buf, temp_id_buf, temp_score_buf], + tag="multibox_transform_loc_first_step") + + out = \ + tvm.extern([oshape], + [loc_pred, anchor, temp_flag, temp_id, temp_score], lambda ins, outs: transform_loc_ir( - ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances), - dtype=[valid_count_dtype, cls_prob.dtype], - out_buffers=[valid_count_buf, out_buf], + ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \ + variances, batch_size, num_classes, num_anchors), + dtype=[cls_prob.dtype], + out_buffers=[out_buf], tag="multibox_transform_loc") return [out, valid_count] @@ -356,5 +436,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) + out = nms( + inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) return out