diff --git a/src/operator-run.c b/src/operator-run.c index 24d301f837c..6b7e1af90de 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -422,48 +422,15 @@ void xnn_compute_grouped_gemm(struct gemm_context* restrict context, void xnn_compute_gemm(struct gemm_context* restrict context, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { - const size_t a_stride = context->a_stride; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->fused_params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } + xnn_compute_hmp_gemm(context, XNN_UARCH_DEFAULT, nr_block_start, + mr_block_start, nr_block_size, mr_block_size); } void xnn_compute_dqgemm(struct gemm_context* restrict context, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { - const size_t a_stride = context->a_stride; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->fused_params, - &context->quantization_params[mr_block_start]); - 
mr_block_size -= mr_step; - mr_block_start += mr_step; - } + xnn_compute_hmp_dqgemm(context, XNN_UARCH_DEFAULT, nr_block_start, + mr_block_start, nr_block_size, mr_block_size); } void xnn_compute_hmp_grouped_qp8gemm(struct gemm_context* restrict context, @@ -732,34 +699,13 @@ void xnn_compute_grouped_inline_packed_qp8gemm(struct gemm_context* context, mr_block_start, mr_block_size); } -void xnn_compute_grouped_batch_igemm(struct igemm_context* restrict context, - size_t batch_index, size_t group_index, - size_t nr_block_start, - size_t mr_block_start, - size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - batch_index * context->bc_stride + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride + - batch_index * context->ba_stride, - context->zero, &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } +void xnn_compute_igemm(struct igemm_context* restrict context, + size_t batch_index, size_t group_index, + size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { + xnn_compute_hmp_igemm(context, XNN_UARCH_DEFAULT, batch_index, group_index, + nr_block_start, mr_block_start, nr_block_size, + mr_block_size); } void xnn_compute_dq_zero_buffer_igemm(struct igemm_context* restrict context, @@ -779,96 +725,29 @@ void xnn_compute_dq_zero_buffer_subconv( } } -void 
xnn_compute_grouped_batch_dqigemm(struct igemm_context* restrict context, - size_t batch_index, size_t group_index, - size_t nr_block_start, - size_t mr_block_start, - size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - batch_index * context->bc_stride + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride + - batch_index * context->ba_stride, - context->zero, context->zero_buffers[batch_index], &context->params, - &context->quantization_params[batch_index]); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_grouped_igemm(struct igemm_context* restrict context, - size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - mr_block_start * cm_stride + - (nr_block_start << 
context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } +void xnn_compute_dqigemm(struct igemm_context* restrict context, + size_t batch_index, size_t group_index, + size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { + xnn_compute_hmp_dqigemm(context, XNN_UARCH_DEFAULT, batch_index, group_index, + nr_block_start, mr_block_start, nr_block_size, + mr_block_size); } -void xnn_compute_grouped_dqigemm(struct igemm_context* restrict context, - size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride, context->zero, - context->zero_buffers[0], &context->params, - context->quantization_params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } +void xnn_compute_inline_packed_igemm( + struct igemm_context* restrict context, uint32_t thread_id, + size_t batch_index, size_t group_index, size_t mr_block_start, + size_t mr_block_size) { + xnn_compute_hmp_inline_packed_igemm(context, XNN_UARCH_DEFAULT, thread_id, + batch_index, group_index, mr_block_start, + mr_block_size); } -static void 
compute_batch_inline_packed_igemm( - struct igemm_context* restrict context, uint32_t uarch_index, - uint32_t thread_id, size_t batch_index, size_t group_index, - size_t mr_block_start, size_t mr_block_size) { +void xnn_compute_hmp_inline_packed_igemm(struct igemm_context* restrict context, + uint32_t uarch_index, size_t thread_id, + size_t batch_index, size_t group_index, + size_t mr_block_start, + size_t mr_block_size) { const size_t mr = context->mr; const size_t mr_packed = context->mr_packed; const size_t kc = context->kc; @@ -905,141 +784,6 @@ static void compute_batch_inline_packed_igemm( } } -void xnn_compute_batch_inline_packed_igemm( - struct igemm_context* restrict context, uint32_t thread_id, - size_t batch_index, size_t mr_block_start, size_t mr_block_size) { - compute_batch_inline_packed_igemm(context, XNN_UARCH_DEFAULT, thread_id, - batch_index, /*group_index=*/0, - mr_block_start, mr_block_size); -} - -void xnn_compute_batch_hmp_inline_packed_igemm( - struct igemm_context* restrict context, uint32_t uarch_index, - size_t thread_id, size_t batch_index, size_t mr_block_start, - size_t mr_block_size) { - compute_batch_inline_packed_igemm(context, uarch_index, thread_id, - batch_index, /*group_index=*/0, - mr_block_start, mr_block_size); -} - -void xnn_compute_grouped_batch_inline_packed_igemm( - struct igemm_context* restrict context, uint32_t thread_id, - size_t batch_index, size_t group_index, size_t mr_block_start, - size_t mr_block_size) { - compute_batch_inline_packed_igemm(context, XNN_UARCH_DEFAULT, thread_id, - batch_index, group_index, mr_block_start, - mr_block_size); -} - -void xnn_compute_grouped_batch_hmp_inline_packed_igemm( - struct igemm_context* restrict context, uint32_t uarch_index, - size_t thread_id, size_t batch_index, size_t group_index, - size_t mr_block_start, size_t mr_block_size) { - compute_batch_inline_packed_igemm(context, uarch_index, thread_id, - batch_index, group_index, mr_block_start, - mr_block_size); -} - -void 
xnn_compute_batch_igemm(struct igemm_context* restrict context, - size_t batch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + batch_index * context->bc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + batch_index * context->ba_stride, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_batch_dqigemm(struct igemm_context* restrict context, - size_t batch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + batch_index * context->bc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + batch_index * context->ba_stride, context->zero, - context->zero_buffers[batch_index], &context->params, - &context->quantization_params[batch_index]); - mr_block_size -= mr_step; - 
mr_block_start += mr_step; - } -} - -void xnn_compute_igemm(struct igemm_context* restrict context, - size_t nr_block_start, size_t mr_block_start, - size_t nr_block_size, size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->a_offset, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_dqigemm(struct igemm_context* restrict context, - size_t nr_block_start, size_t mr_block_start, - size_t nr_block_size, size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->a_offset, context->zero, - context->zero_buffers[0], &context->params, - &context->quantization_params[/*mr_block_start=*/0]); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - // `output_tile_start` should be a multiple of igemm.mr (tile size). 
void xnn_compute_conv2d_igemm_indirection( struct conv2d_igemm_indirection_init_context* restrict context, @@ -1060,47 +804,11 @@ void xnn_compute_conv2d_igemm_indirection( } } -void xnn_compute_grouped_subgemm2d(struct subgemm_context* restrict context, - size_t batch_index, size_t group_index, - size_t subkernel_index, size_t slice_y, - size_t slice_x_start, size_t nc_block_start, - size_t slice_x_max, size_t nc_block_size) { - const struct subconvolution_params* subconvolution_params = - &context->subconvolution_params[subkernel_index]; - - if XNN_UNLIKELY (slice_y >= subconvolution_params->slice_height) { - return; - } - - const size_t slice_width = subconvolution_params->slice_width; - if XNN_UNLIKELY (slice_x_start >= slice_width) { - return; - } - const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); - - const size_t ax_stride = context->ax_stride; - const size_t cx_stride = context->cx_stride; - context->ukernel.function[XNN_UARCH_DEFAULT]( - slice_x_size, nc_block_size, context->kc, - (const void*)((uintptr_t)context->a + group_index * context->ga_stride + - slice_y * context->ay_stride + slice_x_start * ax_stride + - batch_index * context->ba_stride), - ax_stride, - (const void*)((uintptr_t)subconvolution_params->weights + - nc_block_start * subconvolution_params->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)subconvolution_params->output + - group_index * context->gc_stride + slice_y * context->cy_stride + - slice_x_start * cx_stride + batch_index * context->bc_stride + - (nc_block_start << context->log2_csize)), - cx_stride, context->cn_stride, &context->params); -} - -void xnn_compute_grouped_subconv2d(struct subconv_context* restrict context, - size_t batch_index, size_t group_index, - size_t subkernel_index, size_t slice_y, - size_t slice_x_start, size_t nc_block_start, - size_t slice_x_max, size_t nc_block_size) { +void xnn_compute_subconv2d(struct subconv_context* restrict context, + size_t batch_index, 
size_t group_index, + size_t subkernel_index, size_t slice_y, + size_t slice_x_start, size_t nc_block_start, + size_t slice_x_max, size_t nc_block_size) { const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index]; @@ -1135,12 +843,11 @@ void xnn_compute_grouped_subconv2d(struct subconv_context* restrict context, context->zero, &context->params); } -void xnn_compute_grouped_dqsubconv2d(struct subconv_context* restrict context, - size_t batch_index, size_t group_index, - size_t subkernel_index, size_t slice_y, - size_t slice_x_start, - size_t nc_block_start, size_t slice_x_max, - size_t nc_block_size) { +void xnn_compute_dqsubconv2d(struct subconv_context* restrict context, + size_t batch_index, size_t group_index, + size_t subkernel_index, size_t slice_y, + size_t slice_x_start, size_t nc_block_start, + size_t slice_x_max, size_t nc_block_size) { const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index]; @@ -1176,81 +883,6 @@ void xnn_compute_grouped_dqsubconv2d(struct subconv_context* restrict context, &context->quantization_params[batch_index]); } -void xnn_compute_subconv2d(struct subconv_context* restrict context, - size_t batch_index, size_t subkernel_index, - size_t slice_y, size_t slice_x_start, - size_t nc_block_start, size_t slice_x_max, - size_t nc_block_size) { - const struct subconvolution_params* subconvolution_params = - &context->subconvolution_params[subkernel_index]; - - if XNN_UNLIKELY (slice_y >= subconvolution_params->slice_height) { - return; - } - - const size_t slice_width = subconvolution_params->slice_width; - if XNN_UNLIKELY (slice_x_start >= slice_width) { - return; - } - const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); - - const size_t cx_stride = context->cx_stride; - context->ukernel.function[XNN_UARCH_DEFAULT]( - slice_x_size, nc_block_size, context->kc, - subconvolution_params->scaled_kernel_size, - (const 
void**)((uintptr_t)subconvolution_params->indirection_buffer + - slice_y * subconvolution_params->indirection_y_stride + - slice_x_start * - subconvolution_params->indirection_x_stride), - (const void*)((uintptr_t)subconvolution_params->weights + - nc_block_start * subconvolution_params->w_stride), - (void*)((uintptr_t)subconvolution_params->output + - slice_y * context->cy_stride + slice_x_start * cx_stride + - batch_index * context->bc_stride + - (nc_block_start << context->log2_csize)), - cx_stride, context->cn_stride, - context->a_offset + batch_index * context->ba_stride, context->zero, - &context->params); -} - -void xnn_compute_dqsubconv2d(struct subconv_context* restrict context, - size_t batch_index, size_t subkernel_index, - size_t slice_y, size_t slice_x_start, - size_t nc_block_start, size_t slice_x_max, - size_t nc_block_size) { - const struct subconvolution_params* subconvolution_params = - &context->subconvolution_params[subkernel_index]; - - if XNN_UNLIKELY (slice_y >= subconvolution_params->slice_height) { - return; - } - - const size_t slice_width = subconvolution_params->slice_width; - if XNN_UNLIKELY (slice_x_start >= slice_width) { - return; - } - const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); - - const size_t cx_stride = context->cx_stride; - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - slice_x_size, nc_block_size, context->kc, - subconvolution_params->scaled_kernel_size, - (const void**)((uintptr_t)subconvolution_params->indirection_buffer + - slice_y * subconvolution_params->indirection_y_stride + - slice_x_start * - subconvolution_params->indirection_x_stride), - (const void*)((uintptr_t)subconvolution_params->weights + - nc_block_start * subconvolution_params->w_stride), - (void*)((uintptr_t)subconvolution_params->output + - slice_y * context->cy_stride + slice_x_start * cx_stride + - batch_index * context->bc_stride + - (nc_block_start << context->log2_csize)), - cx_stride, context->cn_stride, - 
context->a_offset + batch_index * context->ba_stride, context->zero, - context->zero_buffers[batch_index], &context->params, - &context->quantization_params[batch_index]); -} - void xnn_compute_conv2d_hwc2chw(struct conv2d_context* restrict context, size_t batch_index, size_t output_y_start, size_t output_y_slice) { @@ -1702,10 +1334,10 @@ void xnn_compute_univector_contiguous( context->ukernel(size, x, y, &context->params); } -void xnn_compute_contiguous_reduce( - struct reduce_context* restrict context, size_t output_idx0, - size_t output_idx1, size_t output_idx2, - size_t output2_block_size) { +void xnn_compute_contiguous_reduce(struct reduce_context* restrict context, + size_t output_idx0, size_t output_idx1, + size_t output_idx2, + size_t output2_block_size) { const size_t* input_stride = context->input_stride; const size_t* output_stride = context->output_stride; @@ -2073,7 +1705,6 @@ void xnn_compute_rope(struct rope_context* restrict context, size_t batch_index, context->vcmul(scaled_channels, input, weights, output, NULL); } -#if XNN_MAX_UARCH_TYPES > 1 void xnn_compute_hmp_gemm(struct gemm_context* restrict context, uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, @@ -2121,10 +1752,11 @@ void xnn_compute_hmp_dqgemm(struct gemm_context* restrict context, } } -void xnn_compute_hmp_grouped_batch_igemm( - struct igemm_context* restrict context, uint32_t uarch_index, - size_t batch_index, size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { +void xnn_compute_hmp_igemm(struct igemm_context* restrict context, + uint32_t uarch_index, size_t batch_index, + size_t group_index, size_t nr_block_start, + size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size) { const size_t ks = context->ks; const size_t cm_stride = context->cm_stride; @@ -2149,10 +1781,11 @@ void xnn_compute_hmp_grouped_batch_igemm( } } -void xnn_compute_hmp_grouped_batch_dqigemm( - 
struct igemm_context* restrict context, uint32_t uarch_index, - size_t batch_index, size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { +void xnn_compute_hmp_dqigemm(struct igemm_context* restrict context, + uint32_t uarch_index, size_t batch_index, + size_t group_index, size_t nr_block_start, + size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size) { const size_t ks = context->ks; const size_t cm_stride = context->cm_stride; @@ -2178,166 +1811,6 @@ void xnn_compute_hmp_grouped_batch_dqigemm( } } -void xnn_compute_hmp_grouped_igemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t group_index, - size_t nr_block_start, size_t mr_block_start, - size_t nr_block_size, size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_hmp_grouped_dqigemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t group_index, - size_t nr_block_start, - size_t mr_block_start, - size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, 
context->mr); - context->dq_ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index * context->gw_stride), - (void*)((uintptr_t)context->c + group_index * context->gc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + group_index * context->ga_stride, context->zero, - context->zero_buffers[0], &context->params, - (const void*)((uintptr_t)context->quantization_params)); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_batch_hmp_igemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t batch_index, - size_t nr_block_start, size_t mr_block_start, - size_t nr_block_size, size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + batch_index * context->bc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + batch_index * context->ba_stride, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_batch_hmp_dqigemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t batch_index, - size_t nr_block_start, size_t mr_block_start, - size_t nr_block_size, size_t mr_block_size) { - const size_t ks = context->ks; - 
const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + batch_index * context->bc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, - context->a_offset + batch_index * context->ba_stride, context->zero, - context->zero_buffers[batch_index], &context->params, - &context->quantization_params[batch_index]); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_hmp_igemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - - while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->a_offset, context->zero, - &context->params); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} - -void xnn_compute_hmp_dqigemm(struct igemm_context* restrict context, - uint32_t uarch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size) { - const size_t ks = context->ks; - const size_t cm_stride = context->cm_stride; - 
- while (mr_block_size > 0) { - const size_t mr_step = min(mr_block_size, context->mr); - context->dq_ukernel.function[uarch_index]( - mr_step, nr_block_size, context->kc, context->ks_scaled, - (const void**)((uintptr_t)context->indirect_a + - mr_block_start * ks * sizeof(void*)), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->a_offset, context->zero, - context->zero_buffers[0], &context->params, - (const void*)((uintptr_t)context->quantization_params)); - mr_block_size -= mr_step; - mr_block_start += mr_step; - } -} -#endif // XNN_MAX_UARCH_TYPES > 1 - enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool) { return xnn_run_operator_with_index(op, 0, 0, threadpool); } diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index a4e37b99995..8fdb8c75902 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -2241,166 +2241,57 @@ static enum xnn_status reshape_igemm( dq_zero_buffer_compute->range[0] = batch_size; } - if (groups == 1) { #if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_2d_tile_1d_dynamic_with_uarch_with_thread; - igemm_compute->task_2d_tile_1d_dynamic_with_id_with_thread = - (pthreadpool_task_2d_tile_1d_dynamic_with_id_with_thread_t) - xnn_compute_batch_hmp_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = - xnn_parallelization_type_3d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_batch_hmp_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - 
xnn_compute_batch_hmp_igemm; - } - } else { - igemm_compute->type = - xnn_parallelization_type_2d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_2d_tile_2d_dynamic_with_id = - (pthreadpool_task_2d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_dqigemm; - } else { - igemm_compute->task_2d_tile_2d_dynamic_with_id = - (pthreadpool_task_2d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_igemm; - } - } - } else -#endif // XNN_MAX_UARCH_TYPES > 1 - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_2d_tile_1d_dynamic_with_thread; - igemm_compute->task_2d_tile_1d_dynamic_with_id = - (pthreadpool_task_2d_tile_1d_dynamic_with_id_t) - xnn_compute_batch_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_batch_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_batch_igemm; - } - } else { - igemm_compute->type = xnn_parallelization_type_2d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_2d_tile_2d_dynamic = - (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_dqigemm; - } else { - igemm_compute->task_2d_tile_2d_dynamic = - (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_igemm; - } - } + if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { if (packed_lh_config && inline_lhs_packing) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = output_size; - igemm_compute->tile[0] = mr; - } else if (batch_size > 1) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = group_output_channels; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = + xnn_parallelization_type_3d_tile_1d_dynamic_with_uarch_with_thread; + 
igemm_compute->task_3d_tile_1d_dynamic_with_id_with_thread = + (pthreadpool_task_3d_tile_1d_dynamic_with_id_with_thread_t) + xnn_compute_hmp_inline_packed_igemm; } else { - igemm_compute->range[0] = group_output_channels; - igemm_compute->range[1] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; - } - } else { -#if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_3d_tile_1d_dynamic_with_uarch_with_thread; - igemm_compute->task_3d_tile_1d_dynamic_with_id_with_thread = - (pthreadpool_task_3d_tile_1d_dynamic_with_id_with_thread_t) - xnn_compute_grouped_batch_hmp_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = - xnn_parallelization_type_4d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_4d_tile_2d_dynamic_with_id = - (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_batch_dqigemm; - } else { - igemm_compute->task_4d_tile_2d_dynamic_with_id = - (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_batch_igemm; - } + igemm_compute->type = + xnn_parallelization_type_4d_tile_2d_dynamic_with_uarch; + if (dynamic_quantization) { + igemm_compute->task_4d_tile_2d_dynamic_with_id = + (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) + xnn_compute_hmp_dqigemm; } else { - igemm_compute->type = - xnn_parallelization_type_3d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_igemm; - } + igemm_compute->task_4d_tile_2d_dynamic_with_id = + (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) + xnn_compute_hmp_igemm; } - } else + } + } else #endif // 
XNN_MAX_UARCH_TYPES > 1 - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_3d_tile_1d_dynamic_with_thread; - igemm_compute->task_3d_tile_1d_dynamic_with_id = - (pthreadpool_task_3d_tile_1d_dynamic_with_id_t) - xnn_compute_grouped_batch_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = xnn_parallelization_type_4d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_4d_tile_2d_dynamic = - (pthreadpool_task_4d_tile_2d_dynamic_t) - xnn_compute_grouped_batch_dqigemm; - } else { - igemm_compute->task_4d_tile_2d_dynamic = - (pthreadpool_task_4d_tile_2d_dynamic_t) - xnn_compute_grouped_batch_igemm; - } - } else { - igemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t) - xnn_compute_grouped_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_igemm; - } - } if (packed_lh_config && inline_lhs_packing) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = groups; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = mr; - } else if (batch_size > 1) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = groups; - igemm_compute->range[2] = group_output_channels; - igemm_compute->range[3] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = + xnn_parallelization_type_3d_tile_1d_dynamic_with_thread; + igemm_compute->task_3d_tile_1d_dynamic_with_id = + (pthreadpool_task_3d_tile_1d_dynamic_with_id_t) + xnn_compute_inline_packed_igemm; } else { - igemm_compute->range[0] = groups; - igemm_compute->range[1] = group_output_channels; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = xnn_parallelization_type_4d_tile_2d_dynamic; + if 
(dynamic_quantization) { + igemm_compute->task_4d_tile_2d_dynamic = + (pthreadpool_task_4d_tile_2d_dynamic_t)xnn_compute_dqigemm; + } else { + igemm_compute->task_4d_tile_2d_dynamic = + (pthreadpool_task_4d_tile_2d_dynamic_t)xnn_compute_igemm; + } } + if (packed_lh_config && inline_lhs_packing) { + igemm_compute->range[0] = batch_size; + igemm_compute->range[1] = groups; + igemm_compute->range[2] = output_size; + igemm_compute->tile[0] = mr; + } else { + igemm_compute->range[0] = batch_size; + igemm_compute->range[1] = groups; + igemm_compute->range[2] = group_output_channels; + igemm_compute->range[3] = output_size; + igemm_compute->tile[0] = nc; + igemm_compute->tile[1] = mr; } convolution_op->state = xnn_run_state_needs_setup; diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index 6ab8f2eb443..2ddf3299b69 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -1336,167 +1336,58 @@ static enum xnn_status reshape_igemm_path( dq_zero_buffer_compute->range[0] = batch_size; } - if (groups == 1) { #if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_2d_tile_1d_dynamic_with_uarch_with_thread; - igemm_compute->task_2d_tile_1d_dynamic_with_id_with_thread = - (pthreadpool_task_2d_tile_1d_dynamic_with_id_with_thread_t) - xnn_compute_batch_hmp_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = - xnn_parallelization_type_3d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_batch_hmp_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_batch_hmp_igemm; - } - } else { - igemm_compute->type = - xnn_parallelization_type_2d_tile_2d_dynamic_with_uarch; - if 
(dynamic_quantization) { - igemm_compute->task_2d_tile_2d_dynamic_with_id = - (pthreadpool_task_2d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_dqigemm; - } else { - igemm_compute->task_2d_tile_2d_dynamic_with_id = - (pthreadpool_task_2d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_igemm; - } - } - } else -#endif // XNN_MAX_UARCH_TYPES > 1 - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_2d_tile_1d_dynamic_with_thread; - igemm_compute->task_2d_tile_1d_dynamic_with_id = - (pthreadpool_task_2d_tile_1d_dynamic_with_id_t) - xnn_compute_batch_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_batch_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_batch_igemm; - } - } else { - igemm_compute->type = xnn_parallelization_type_2d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_2d_tile_2d_dynamic = - (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_dqigemm; - } else { - igemm_compute->task_2d_tile_2d_dynamic = - (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_igemm; - } - } + if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { if (packed_lh_config && inline_lhs_packing) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = output_size; - igemm_compute->tile[0] = mr; - } else if (batch_size > 1) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = group_output_channels; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = + xnn_parallelization_type_3d_tile_1d_dynamic_with_uarch_with_thread; + igemm_compute->task_3d_tile_1d_dynamic_with_id_with_thread = + (pthreadpool_task_3d_tile_1d_dynamic_with_id_with_thread_t) + 
xnn_compute_hmp_inline_packed_igemm; } else { - igemm_compute->range[0] = group_output_channels; - igemm_compute->range[1] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; - } - } else { -#if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) { - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - xnn_parallelization_type_3d_tile_1d_dynamic_with_uarch_with_thread; - igemm_compute->task_3d_tile_1d_dynamic_with_id_with_thread = - (pthreadpool_task_3d_tile_1d_dynamic_with_id_with_thread_t) - xnn_compute_grouped_batch_hmp_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = - xnn_parallelization_type_4d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_4d_tile_2d_dynamic_with_id = - (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_batch_dqigemm; - } else { - igemm_compute->task_4d_tile_2d_dynamic_with_id = - (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_batch_igemm; - } + igemm_compute->type = + xnn_parallelization_type_4d_tile_2d_dynamic_with_uarch; + if (dynamic_quantization) { + igemm_compute->task_4d_tile_2d_dynamic_with_id = + (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) + xnn_compute_hmp_dqigemm; } else { - igemm_compute->type = - xnn_parallelization_type_3d_tile_2d_dynamic_with_uarch; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic_with_id = - (pthreadpool_task_3d_tile_2d_dynamic_with_id_t) - xnn_compute_hmp_grouped_igemm; - } + igemm_compute->task_4d_tile_2d_dynamic_with_id = + (pthreadpool_task_4d_tile_2d_dynamic_with_id_t) + xnn_compute_hmp_igemm; } - } else + } + } else #endif // XNN_MAX_UARCH_TYPES > 1 - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->type = - 
xnn_parallelization_type_3d_tile_1d_dynamic_with_thread; - igemm_compute->task_3d_tile_1d_dynamic_with_id = - (pthreadpool_task_3d_tile_1d_dynamic_with_id_t) - xnn_compute_grouped_batch_inline_packed_igemm; - } else if (batch_size > 1) { - igemm_compute->type = xnn_parallelization_type_4d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_4d_tile_2d_dynamic = - (pthreadpool_task_4d_tile_2d_dynamic_t) - xnn_compute_grouped_batch_dqigemm; - } else { - igemm_compute->task_4d_tile_2d_dynamic = - (pthreadpool_task_4d_tile_2d_dynamic_t) - xnn_compute_grouped_batch_igemm; - } - } else { - igemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; - if (dynamic_quantization) { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t) - xnn_compute_grouped_dqigemm; - } else { - igemm_compute->task_3d_tile_2d_dynamic = - (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_igemm; - } - } - if (packed_lh_config && inline_lhs_packing) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = groups; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = mr; - } else if (batch_size > 1) { - igemm_compute->range[0] = batch_size; - igemm_compute->range[1] = groups; - igemm_compute->range[2] = group_output_channels; - igemm_compute->range[3] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = + xnn_parallelization_type_3d_tile_1d_dynamic_with_thread; + igemm_compute->task_3d_tile_1d_dynamic_with_id = + (pthreadpool_task_3d_tile_1d_dynamic_with_id_t) + xnn_compute_inline_packed_igemm; } else { - igemm_compute->range[0] = groups; - igemm_compute->range[1] = group_output_channels; - igemm_compute->range[2] = output_size; - igemm_compute->tile[0] = nc; - igemm_compute->tile[1] = mr; + igemm_compute->type = xnn_parallelization_type_4d_tile_2d_dynamic; + if (dynamic_quantization) { + igemm_compute->task_4d_tile_2d_dynamic = + 
(pthreadpool_task_4d_tile_2d_dynamic_t)xnn_compute_dqigemm; + } else { + igemm_compute->task_4d_tile_2d_dynamic = + (pthreadpool_task_4d_tile_2d_dynamic_t)xnn_compute_igemm; + } } + + if (packed_lh_config && inline_lhs_packing) { + igemm_compute->range[0] = batch_size; + igemm_compute->range[1] = groups; + igemm_compute->range[2] = output_size; + igemm_compute->tile[0] = mr; + } else { + igemm_compute->range[0] = batch_size; + igemm_compute->range[1] = groups; + igemm_compute->range[2] = group_output_channels; + igemm_compute->range[3] = output_size; + igemm_compute->tile[0] = nc; + igemm_compute->tile[1] = mr; } deconvolution_op->state = xnn_run_state_needs_setup; @@ -1701,50 +1592,28 @@ static enum xnn_status reshape_subconv2d_path( deconvolution_op->compute[igemm_compute_index].tile[0] = 1; ++igemm_compute_index; } - if (groups == 1) { - deconvolution_op->compute[igemm_compute_index].type = - xnn_parallelization_type_5d_tile_2d; - if (dynamic_quantization) { - deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = - (pthreadpool_task_5d_tile_2d_t)xnn_compute_dqsubconv2d; - } else { - deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = - (pthreadpool_task_5d_tile_2d_t)xnn_compute_subconv2d; - } - deconvolution_op->compute[igemm_compute_index].range[0] = batch_size; - deconvolution_op->compute[igemm_compute_index].range[1] = - stride_height * stride_width; - deconvolution_op->compute[igemm_compute_index].range[2] = - output_height_positions; - deconvolution_op->compute[igemm_compute_index].range[3] = - output_width_positions; - deconvolution_op->compute[igemm_compute_index].range[4] = - group_output_channels; - deconvolution_op->compute[igemm_compute_index].tile[0] = mr; - deconvolution_op->compute[igemm_compute_index].tile[1] = nc; + + deconvolution_op->compute[igemm_compute_index].type = + xnn_parallelization_type_6d_tile_2d; + if (dynamic_quantization) { + deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = + 
(pthreadpool_task_6d_tile_2d_t)xnn_compute_dqsubconv2d; } else { - deconvolution_op->compute[igemm_compute_index].type = - xnn_parallelization_type_6d_tile_2d; - if (dynamic_quantization) { - deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = - (pthreadpool_task_6d_tile_2d_t)xnn_compute_grouped_dqsubconv2d; - } else { - deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = - (pthreadpool_task_6d_tile_2d_t)xnn_compute_grouped_subconv2d; - } - deconvolution_op->compute[igemm_compute_index].range[0] = batch_size; - deconvolution_op->compute[igemm_compute_index].range[1] = groups; - deconvolution_op->compute[igemm_compute_index].range[2] = - stride_height * stride_width; - deconvolution_op->compute[igemm_compute_index].range[3] = - output_height_positions; - deconvolution_op->compute[igemm_compute_index].range[4] = - output_width_positions; - deconvolution_op->compute[igemm_compute_index].range[5] = - group_output_channels; - deconvolution_op->compute[igemm_compute_index].tile[0] = mr; - deconvolution_op->compute[igemm_compute_index].tile[1] = nc; - } + deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = + (pthreadpool_task_6d_tile_2d_t)xnn_compute_subconv2d; + } + deconvolution_op->compute[igemm_compute_index].range[0] = batch_size; + deconvolution_op->compute[igemm_compute_index].range[1] = groups; + deconvolution_op->compute[igemm_compute_index].range[2] = + stride_height * stride_width; + deconvolution_op->compute[igemm_compute_index].range[3] = + output_height_positions; + deconvolution_op->compute[igemm_compute_index].range[4] = + output_width_positions; + deconvolution_op->compute[igemm_compute_index].range[5] = + group_output_channels; + deconvolution_op->compute[igemm_compute_index].tile[0] = mr; + deconvolution_op->compute[igemm_compute_index].tile[1] = nc; deconvolution_op->state = xnn_run_state_needs_setup; return xnn_status_success; diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 
6dd285c9ecf..84c53a5e26a 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -402,7 +402,6 @@ XNN_PRIVATE void xnn_compute_grouped_inline_packed_qp8gemm( struct gemm_context* context, uint32_t thread_id, size_t group_index, size_t mr_block_start, size_t mr_block_size); -#if XNN_MAX_UARCH_TYPES > 1 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( struct gemm_context* context, uint32_t uarch_index, size_t group_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, @@ -432,7 +431,6 @@ XNN_PRIVATE void xnn_compute_hmp_qp8gemm( XNN_PRIVATE void xnn_compute_hmp_inline_packed_qp8gemm( struct gemm_context* context, uint32_t uarch_index, size_t thread_id, size_t mr_block_start, size_t mr_block_size); -#endif // XNN_MAX_UARCH_TYPES > 1 // Context for Sparse Matrix-Dense Matrix Multiplication. // C [MxN] := A [MxK] * B [KxN] + bias [N] @@ -569,134 +567,43 @@ struct igemm_context { size_t per_thread_workspace_size; }; -XNN_PRIVATE void xnn_compute_grouped_dqigemm( - struct igemm_context* context, size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_grouped_igemm( - struct igemm_context* context, size_t group_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_grouped_batch_dqigemm( - struct igemm_context* context, size_t batch_index, size_t group_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_grouped_batch_igemm( - struct igemm_context* context, size_t batch_index, size_t group_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_dq_zero_buffer_igemm(struct igemm_context* context, - size_t batch_index); - XNN_PRIVATE void xnn_compute_dqigemm(struct igemm_context* context, + size_t batch_index, size_t 
group_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); XNN_PRIVATE void xnn_compute_igemm(struct igemm_context* context, + size_t batch_index, size_t group_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); +XNN_PRIVATE void xnn_compute_dq_zero_buffer_igemm(struct igemm_context* context, + size_t batch_index); + XNN_PRIVATE void xnn_compute_conv2d_igemm_indirection( struct conv2d_igemm_indirection_init_context* context, size_t output_tile_start, size_t output_tile_size); -XNN_PRIVATE void xnn_compute_batch_dqigemm( - struct igemm_context* context, size_t batch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_batch_igemm( - struct igemm_context* context, size_t batch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_batch_inline_packed_igemm( - struct igemm_context* context, uint32_t thread_id, size_t batch_index, - size_t mr_block_start, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_grouped_batch_inline_packed_igemm( +XNN_PRIVATE void xnn_compute_inline_packed_igemm( struct igemm_context* context, uint32_t thread_id, size_t batch_index, size_t group_index, size_t mr_block_start, size_t mr_block_size); -#if XNN_MAX_UARCH_TYPES > 1 -XNN_PRIVATE void xnn_compute_hmp_grouped_igemm( - struct igemm_context* context, uint32_t uarch_index, size_t group_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_hmp_grouped_dqigemm( - struct igemm_context* context, uint32_t uarch_index, size_t group_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_hmp_grouped_batch_dqigemm( +XNN_PRIVATE void xnn_compute_hmp_dqigemm( struct igemm_context* context, 
uint32_t uarch_index, size_t batch_index, size_t group_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); -XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm( +XNN_PRIVATE void xnn_compute_hmp_igemm( struct igemm_context* context, uint32_t uarch_index, size_t batch_index, size_t group_index, size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); -XNN_PRIVATE void xnn_compute_hmp_dqigemm( - struct igemm_context* context, uint32_t uarch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_hmp_igemm( - struct igemm_context* context, uint32_t uarch_index, size_t nr_block_start, - size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_batch_hmp_dqigemm( - struct igemm_context* context, uint32_t uarch_index, size_t batch_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_batch_hmp_igemm( - struct igemm_context* context, uint32_t uarch_index, size_t batch_index, - size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, - size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_batch_hmp_inline_packed_igemm( - struct igemm_context* context, uint32_t uarch_index, size_t thread_id, - size_t batch_index, size_t mr_block_start, size_t mr_block_size); - -XNN_PRIVATE void xnn_compute_grouped_batch_hmp_inline_packed_igemm( +XNN_PRIVATE void xnn_compute_hmp_inline_packed_igemm( struct igemm_context* context, uint32_t uarch_index, size_t thread_id, size_t batch_index, size_t group_index, size_t mr_block_start, size_t mr_block_size); -#endif // XNN_MAX_UARCH_TYPES > 1 - -struct subgemm_context { - const struct subconvolution_params* subconvolution_params; - size_t kc; - const void* a; - size_t ax_stride; - size_t ay_stride; - size_t cx_stride; - size_t cy_stride; - size_t cn_stride; - 
size_t ga_stride; - size_t gw_stride; - size_t gc_stride; - size_t ba_stride; - size_t bc_stride; - uint32_t log2_csize; - struct xnn_hmp_gemm_ukernel ukernel; - union { - union xnn_qs8_conv_minmax_params qs8; - union xnn_qu8_conv_minmax_params qu8; - struct xnn_f16_scaleminmax_params f16; - struct xnn_f32_minmax_params f32; - } params; -}; - -XNN_PRIVATE void xnn_compute_grouped_subgemm2d( - struct subgemm_context* context, size_t batch_index, size_t group_index, - size_t subkernel_index, size_t slice_y, size_t slice_x_start, - size_t nc_block_start, size_t slice_x_max, size_t nc_block_size); struct subconv_context { const struct subconvolution_params* subconvolution_params; @@ -733,26 +640,16 @@ struct subconv_context { XNN_PRIVATE void xnn_compute_dq_zero_buffer_subconv( struct subconv_context* context, size_t batch_index, size_t batch_size); -XNN_PRIVATE void xnn_compute_grouped_subconv2d( +XNN_PRIVATE void xnn_compute_subconv2d( struct subconv_context* context, size_t batch_index, size_t group_index, size_t subkernel_index, size_t slice_y, size_t slice_x_start, size_t nc_block_start, size_t slice_x_max, size_t nc_block_size); -XNN_PRIVATE void xnn_compute_grouped_dqsubconv2d( +XNN_PRIVATE void xnn_compute_dqsubconv2d( struct subconv_context* context, size_t batch_index, size_t group_index, size_t subkernel_index, size_t slice_y, size_t slice_x_start, size_t nc_block_start, size_t slice_x_max, size_t nc_block_size); -XNN_PRIVATE void xnn_compute_subconv2d( - struct subconv_context* context, size_t batch_index, size_t subkernel_index, - size_t slice_y, size_t slice_x_start, size_t nc_block_start, - size_t slice_x_max, size_t nc_block_size); - -XNN_PRIVATE void xnn_compute_dqsubconv2d( - struct subconv_context* context, size_t batch_index, size_t subkernel_index, - size_t slice_y, size_t slice_x_start, size_t nc_block_start, - size_t slice_x_max, size_t nc_block_size); - struct conv2d_context { size_t input_height; size_t input_width; diff --git 
a/src/xnnpack/operator.h b/src/xnnpack/operator.h index 00e5610aa92..c1c71e0002e 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -349,7 +349,6 @@ struct xnn_operator { struct slice_context slice; struct spmm_context spmm; struct subconv_context subconv; - struct subgemm_context subgemm; struct transpose_context transpose; struct floating_point_softmax_context floating_point_softmax; struct u8_softmax_context u8_softmax;