Skip to content

Commit 08d8a18

Browse files
gonnet authored and xnnpack-bot committed
Split large fully-connected nodes along k.
This is just a `subgraph` rewrite that replaces nodes that look like

```
fully_connected(input_id, weights_id, bias_id)
```

where the data along the `k`th dimension exceeds 16kB, with

```
fully_connected(static_slice(K0, input_id), static_slice(K0, weights_id), bias_id) +
fully_connected(static_slice(K1, input_id), static_slice(K1, weights_id), 0) +
... +
fully_connected(static_slice(Kn, input_id), static_slice(Kn, weights_id), 0)
```

where `static_slice(ival, value_id)` slices `value_id` in the `k` dimension.

PiperOrigin-RevId: 806264392
1 parent 6eba261 commit 08d8a18

File tree

16 files changed

+777
-207
lines changed

16 files changed

+777
-207
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ xnnpack_cc_library(
839839
":memory",
840840
":microkernel_hdrs",
841841
":microkernel_type",
842+
":microkernel_utils",
842843
":mutex",
843844
":node_type",
844845
":operator_type",

src/datatype.c

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
// Copyright 2019 Google LLC
1+
// Copyright 2019-2025 Google LLC
22
//
33
// This source code is licensed under the BSD-style license found in the
44
// LICENSE file in the root directory of this source tree.
55

66
#include "src/xnnpack/datatype.h"
77

8+
#include <assert.h>
9+
#include <stddef.h>
10+
811
#include "include/xnnpack.h"
12+
#include "src/xnnpack/common.h"
913

1014
bool xnn_datatype_is_real(enum xnn_datatype t) {
1115
switch (t) {
@@ -88,6 +92,20 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) {
8892
return false;
8993
}
9094

95+
bool xnn_datatype_is_dynamically_quantized(enum xnn_datatype t) {
96+
switch (t) {
97+
case xnn_datatype_pqint8:
98+
case xnn_datatype_qdint8:
99+
case xnn_datatype_qduint8:
100+
case xnn_datatype_qpint8:
101+
return true;
102+
default:
103+
return false;
104+
}
105+
XNN_UNREACHABLE;
106+
return false;
107+
}
108+
91109
bool xnn_datatype_is_channelwise_quantized(enum xnn_datatype t) {
92110
switch (t) {
93111
case xnn_datatype_qcint8:

src/operator-run.c

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,25 +1494,30 @@ void xnn_compute_pad_5d(struct pad_context* restrict context, size_t i,
14941494
}
14951495
}
14961496

1497-
void xnn_compute_slice_1d(struct slice_context* restrict context, size_t i) {
1498-
const void* input =
1499-
(const void*)((uintptr_t)context->input + i * context->input_stride[0]);
1500-
void* output =
1501-
(void*)((uintptr_t)context->output + i * context->output_stride[0]);
1497+
void xnn_compute_slice_1d(struct slice_context* restrict context, size_t offset,
1498+
size_t count) {
1499+
for (size_t i = offset; i < offset + count; i++) {
1500+
const void* input =
1501+
(const void*)((uintptr_t)context->input + i * context->input_stride[0]);
1502+
void* output =
1503+
(void*)((uintptr_t)context->output + i * context->output_stride[0]);
15021504

1503-
context->ukernel(context->contiguous_size, input, output, NULL);
1505+
context->ukernel(context->contiguous_size, input, output, NULL);
1506+
}
15041507
}
15051508

15061509
void xnn_compute_slice_2d(struct slice_context* restrict context, size_t i,
1507-
size_t j) {
1508-
const void* input =
1509-
(const void*)((uintptr_t)context->input + i * context->input_stride[1] +
1510-
j * context->input_stride[0]);
1511-
void* output =
1512-
(void*)((uintptr_t)context->output + i * context->output_stride[1] +
1513-
j * context->output_stride[0]);
1510+
size_t offset, size_t count) {
1511+
for (size_t j = offset; j < offset + count; j++) {
1512+
const void* input =
1513+
(const void*)((uintptr_t)context->input + i * context->input_stride[1] +
1514+
j * context->input_stride[0]);
1515+
void* output =
1516+
(void*)((uintptr_t)context->output + i * context->output_stride[1] +
1517+
j * context->output_stride[0]);
15141518

1515-
context->ukernel(context->contiguous_size, input, output, NULL);
1519+
context->ukernel(context->contiguous_size, input, output, NULL);
1520+
}
15161521
}
15171522

15181523
void xnn_compute_slice_3d(struct slice_context* restrict context, size_t i,

src/operators/fully-connected-nc.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,9 +2302,9 @@ static enum xnn_status reshape_fully_connected_nc(
23022302
const size_t nc = xnn_gemm_best_tile_size(
23032303
/*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels,
23042304
/*m_stride=*/
2305-
fully_connected_op->input_pixel_stride
2306-
<< (packed_lh_config ? packed_lh_config->log2_packed_element_size
2307-
: log2_input_element_size),
2305+
input_channels << (packed_lh_config
2306+
? packed_lh_config->log2_packed_element_size
2307+
: log2_input_element_size),
23082308
/*n_stride=*/
23092309
fully_connected_op->weights_stride,
23102310
/*cn_stride=*/1 << log2_output_element_size, mr, nr,
@@ -2362,7 +2362,7 @@ static enum xnn_status reshape_fully_connected_nc(
23622362
.mr = mr_packed,
23632363
.kr = kr,
23642364
.sr = sr,
2365-
.lhs_stride = input_channels
2365+
.lhs_stride = fully_connected_op->input_pixel_stride
23662366
<< packed_lh_config->log2_input_element_size,
23672367
.packed_offset_fn = packed_lh_config->offset_fn,
23682368
.pack_lh_ukernel = packed_lh_config->pack_lh_fn,
@@ -2411,6 +2411,7 @@ static enum xnn_status reshape_fully_connected_nc(
24112411
<< log2_output_element_size,
24122412
.cn_stride = nr << log2_output_element_size,
24132413
.log2_csize = log2_output_element_size,
2414+
.log2_asize = log2_input_element_size,
24142415
.ukernel = gemm_ukernel,
24152416
.mr = mr,
24162417
.nc = output_channels,
@@ -2873,6 +2874,11 @@ static enum xnn_status setup_fully_connected_nc(
28732874
struct gemm_op_context* gemm_context =
28742875
fully_connected_op->dynamic_context.gemm;
28752876

2877+
input =
2878+
(const void*)((uintptr_t)input +
2879+
(fully_connected_op->fully_connected.input_element_offset
2880+
<< gemm_context->gemm.log2_asize));
2881+
28762882
if (fully_connected_op->num_compute_invocations == 2) {
28772883
gemm_context->pack_lh.lhs = input;
28782884
gemm_context->pack_lh.lhs_packed = workspace;

src/operators/slice-nd.c

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "src/xnnpack/config-types.h"
1717
#include "src/xnnpack/config.h"
1818
#include "src/xnnpack/log.h"
19+
#include "src/xnnpack/math.h"
1920
#include "src/xnnpack/normalization.h"
2021
#include "src/xnnpack/operator-type.h"
2122
#include "src/xnnpack/operator-utils.h"
@@ -221,15 +222,21 @@ static enum xnn_status reshape_slice_nd(
221222
switch (num_normalized_dims) {
222223
case 1:
223224
case 2:
224-
slice_op->compute[0].type = xnn_parallelization_type_1d;
225-
slice_op->compute[0].task_1d = (pthreadpool_task_1d_t)xnn_compute_slice_1d;
225+
slice_op->compute[0].type = xnn_parallelization_type_1d_tile_1d_dynamic;
226+
slice_op->compute[0].task_1d_tile_1d_dynamic =
227+
(pthreadpool_task_1d_tile_1d_dynamic_t)xnn_compute_slice_1d;
226228
slice_op->compute[0].range[0] = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 2];
229+
slice_op->compute[0].tile[0] =
230+
divide_round_up(slice_op->context.slice.contiguous_size, 32 * 1024);
227231
break;
228232
case 3:
229-
slice_op->compute[0].type = xnn_parallelization_type_2d;
230-
slice_op->compute[0].task_2d = (pthreadpool_task_2d_t) xnn_compute_slice_2d;
233+
slice_op->compute[0].type = xnn_parallelization_type_2d_tile_1d_dynamic;
234+
slice_op->compute[0].task_2d_tile_1d_dynamic =
235+
(pthreadpool_task_2d_tile_1d_dynamic_t)xnn_compute_slice_2d;
231236
slice_op->compute[0].range[0] = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 3];
232237
slice_op->compute[0].range[1] = normalized_output_shape[XNN_MAX_TENSOR_DIMS - 2];
238+
slice_op->compute[0].tile[0] =
239+
divide_round_up(slice_op->context.slice.contiguous_size, 32 * 1024);
233240
break;
234241
case 4:
235242
slice_op->compute[0].type = xnn_parallelization_type_3d;

src/runtime.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,8 @@ enum xnn_status xnn_create_runtime_with_threadpool(
828828
threadpool =
829829
pthreadpool_create_v2(&executor, xnn_threadpool->scheduler_context, 0);
830830
flags |= XNN_FLAG_RUNTIME_OWNS_THREADPOOL;
831+
xnn_log_info("Created pthreadpool from scheduler with %zu threads.",
832+
pthreadpool_get_threads_count(threadpool));
831833
#endif // XNN_SLINKY_ENABLED
832834
}
833835

0 commit comments

Comments
 (0)