Skip to content

Commit 3dbc2f2

Browse files
gonnet and xnnpack-bot
authored and committed
Split large fully-connected nodes along k.
This is just a `subgraph` rewrite that replaces nodes that look like

```
fully_connected(input_id, weights_id, bias_id)
```

where the data along the `k`th dimension exceeds 16kB, with

```
fully_connected(static_slice(K0, input_id), static_slice(K0, weights_id), bias_id) +
fully_connected(static_slice(K1, input_id), static_slice(K1, weights_id), 0) +
... +
fully_connected(static_slice(Kn, input_id), static_slice(Kn, weights_id), 0)
```

where `static_slice(ival, value_id)` slices `value_id` in the `k` dimension.

PiperOrigin-RevId: 806264392
1 parent b2a89c9 commit 3dbc2f2

21 files changed

+634
-187
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ xnnpack_cc_library(
839839
":memory",
840840
":microkernel_hdrs",
841841
":microkernel_type",
842+
":microkernel_utils",
842843
":mutex",
843844
":node_type",
844845
":operator_type",

cmake/gen/amd64_microkernels.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ SET(PROD_AMD64_ASM_MICROKERNEL_SRCS
1414
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
1515
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-fma3-broadcast.S
1616
src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S
17-
src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
1817
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-fma3-broadcast.S
19-
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S
2018
src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-fma3-broadcast.S
2119
src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S
2220
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S
@@ -53,6 +51,7 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
5351
src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-amd64-fma3-broadcast.S
5452
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S
5553
src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S
54+
src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
5655
src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S
5756
src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-amd64-fma3-broadcast.S
5857
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S
@@ -79,6 +78,7 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
7978
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-fma3-broadcast.S
8079
src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S
8180
src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S
81+
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S
8282
src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S
8383
src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-amd64-fma3-broadcast.S
8484
src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S

cmake/gen/avx512f_microkernels.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
1717
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c
1818
src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c
1919
src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c
20+
src/f32-gemm/gen/f32-gemm-1x64-minmax-avx512f-broadcast.c
21+
src/f32-gemm/gen/f32-gemm-5x64-minmax-avx512f-broadcast.c
2022
src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c
2123
src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c
2224
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c
@@ -81,8 +83,7 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
8183
src/x32-packw/gen/x32-packw-x16-gemm-gio-avx512f-u8.c
8284
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c
8385
src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8.c
84-
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c
85-
src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4-prfm.c)
86+
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c)
8687

8788
SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
8889
src/f32-dwconv/gen/f32-dwconv-3p16c-minmax-avx512f-acc2.c
@@ -97,13 +98,11 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
9798
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c
9899
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c
99100
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c
100-
src/f32-gemm/gen/f32-gemm-1x64-minmax-avx512f-broadcast.c
101101
src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c
102102
src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c
103103
src/f32-gemm/gen/f32-gemm-4x64-minmax-avx512f-broadcast.c
104104
src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c
105105
src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c
106-
src/f32-gemm/gen/f32-gemm-5x64-minmax-avx512f-broadcast.c
107106
src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c
108107
src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c
109108
src/f32-gemm/gen/f32-gemm-6x64-minmax-avx512f-broadcast.c
@@ -279,6 +278,7 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
279278
src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u1.c
280279
src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8-prfm.c
281280
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c
281+
src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4-prfm.c
282282
src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4.c)
283283

284284
SET(ALL_AVX512F_MICROKERNEL_SRCS ${PROD_AVX512F_MICROKERNEL_SRCS} + ${NON_PROD_AVX512F_MICROKERNEL_SRCS})

gen/amd64_microkernels.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ PROD_AMD64_ASM_MICROKERNEL_SRCS = [
1010
"src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S",
1111
"src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-fma3-broadcast.S",
1212
"src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S",
13-
"src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S",
1413
"src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-fma3-broadcast.S",
15-
"src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S",
1614
"src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-fma3-broadcast.S",
1715
"src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S",
1816
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S",
@@ -50,6 +48,7 @@ NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [
5048
"src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-amd64-fma3-broadcast.S",
5149
"src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S",
5250
"src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S",
51+
"src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S",
5352
"src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S",
5453
"src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-amd64-fma3-broadcast.S",
5554
"src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S",
@@ -76,6 +75,7 @@ NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [
7675
"src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-fma3-broadcast.S",
7776
"src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S",
7877
"src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S",
78+
"src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S",
7979
"src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S",
8080
"src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-amd64-fma3-broadcast.S",
8181
"src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S",

gen/avx512f_microkernels.bzl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
1313
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c",
1414
"src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c",
1515
"src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c",
16+
"src/f32-gemm/gen/f32-gemm-1x64-minmax-avx512f-broadcast.c",
17+
"src/f32-gemm/gen/f32-gemm-5x64-minmax-avx512f-broadcast.c",
1618
"src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c",
1719
"src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c",
1820
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c",
@@ -78,7 +80,6 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
7880
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c",
7981
"src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8.c",
8082
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c",
81-
"src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4-prfm.c",
8283
]
8384

8485
NON_PROD_AVX512F_MICROKERNEL_SRCS = [
@@ -94,13 +95,11 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
9495
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c",
9596
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c",
9697
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c",
97-
"src/f32-gemm/gen/f32-gemm-1x64-minmax-avx512f-broadcast.c",
9898
"src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c",
9999
"src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c",
100100
"src/f32-gemm/gen/f32-gemm-4x64-minmax-avx512f-broadcast.c",
101101
"src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c",
102102
"src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c",
103-
"src/f32-gemm/gen/f32-gemm-5x64-minmax-avx512f-broadcast.c",
104103
"src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c",
105104
"src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c",
106105
"src/f32-gemm/gen/f32-gemm-6x64-minmax-avx512f-broadcast.c",
@@ -276,6 +275,7 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
276275
"src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u1.c",
277276
"src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8-prfm.c",
278277
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c",
278+
"src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4-prfm.c",
279279
"src/x32-packw/gen/x32-packw-x32c2-gemm-goi-avx512f-u4.c",
280280
]
281281

src/configs/gemm-config.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -802,14 +802,14 @@ static void init_f32_gemm_config_impl(struct xnn_gemm_config* f32_gemm_config, b
802802
(void) hardware_config; // May be unused.
803803
#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && !XNN_PLATFORM_WINDOWS && XNN_ENABLE_ASSEMBLY
804804
if (!consistent_arithmetic && hardware_config->arch_flags & xnn_arch_x86_avx512f) {
805-
f32_gemm_config->minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast);
806-
f32_gemm_config->minmax.gemm[XNN_MR_TO_INDEX(5)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast);
805+
f32_gemm_config->minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_f32_gemm_minmax_ukernel_1x64__avx512f_broadcast);
806+
f32_gemm_config->minmax.gemm[XNN_MR_TO_INDEX(5)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_f32_gemm_minmax_ukernel_5x64__avx512f_broadcast);
807807
f32_gemm_config->init.f32 = xnn_init_f32_minmax_scalar_params;
808-
f32_gemm_config->pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w;
809-
f32_gemm_config->pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32c2__avx512f_u4_prfm;
808+
f32_gemm_config->pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x32__avx512f_u8;
809+
f32_gemm_config->pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32__avx512f_u4_prfm;
810810
f32_gemm_config->mr = 5;
811-
f32_gemm_config->nr = 32;
812-
f32_gemm_config->log2_kr = 1;
811+
f32_gemm_config->nr = 64;
812+
f32_gemm_config->log2_kr = 0;
813813
f32_gemm_config->log2_sr = 0;
814814
} else
815815
#endif

src/datatype.c

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
// Copyright 2019 Google LLC
1+
// Copyright 2019-2025 Google LLC
22
//
33
// This source code is licensed under the BSD-style license found in the
44
// LICENSE file in the root directory of this source tree.
55

66
#include "src/xnnpack/datatype.h"
77

8+
#include <assert.h>
9+
#include <stddef.h>
10+
811
#include "include/xnnpack.h"
12+
#include "src/xnnpack/common.h"
913

1014
bool xnn_datatype_is_real(enum xnn_datatype t) {
1115
switch (t) {
@@ -88,6 +92,20 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) {
8892
return false;
8993
}
9094

95+
bool xnn_datatype_is_dynamically_quantized(enum xnn_datatype t) {
96+
switch (t) {
97+
case xnn_datatype_pqint8:
98+
case xnn_datatype_qdint8:
99+
case xnn_datatype_qduint8:
100+
case xnn_datatype_qpint8:
101+
return true;
102+
default:
103+
return false;
104+
}
105+
XNN_UNREACHABLE;
106+
return false;
107+
}
108+
91109
bool xnn_datatype_is_channelwise_quantized(enum xnn_datatype t) {
92110
switch (t) {
93111
case xnn_datatype_qcint8:

src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast
4747
mov r11, [rsp + 80]
4848

4949
# Align the stack pointer.
50-
mov r13, rsp
51-
sub rsp, 64
52-
and rsp, 0xFFFFFFFFFFFFFFC0
50+
# mov r13, rsp
51+
# sub rsp, 64
52+
# and rsp, 0xFFFFFFFFFFFFFFC0
5353
# Store the old stack pointer containing the return address
54-
mov [rsp], r13
54+
# mov [rsp], r13
5555

5656
# Allocate some space on the stack.
5757
sub rsp, 192
@@ -329,8 +329,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast
329329

330330
.Lreturn:
331331
add rsp, 192
332-
mov r13, [rsp]
333-
mov rsp, r13
332+
# mov r13, [rsp]
333+
# mov rsp, r13
334334
# Restore the callee saved registers.
335335
pop r12
336336
pop r13

src/operator-run.c

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,25 +1494,30 @@ void xnn_compute_pad_5d(struct pad_context* restrict context, size_t i,
14941494
}
14951495
}
14961496

1497-
void xnn_compute_slice_1d(struct slice_context* restrict context, size_t i) {
1498-
const void* input =
1499-
(const void*)((uintptr_t)context->input + i * context->input_stride[0]);
1500-
void* output =
1501-
(void*)((uintptr_t)context->output + i * context->output_stride[0]);
1497+
void xnn_compute_slice_1d(struct slice_context* restrict context, size_t offset,
1498+
size_t count) {
1499+
for (size_t i = offset; i < offset + count; i++) {
1500+
const void* input =
1501+
(const void*)((uintptr_t)context->input + i * context->input_stride[0]);
1502+
void* output =
1503+
(void*)((uintptr_t)context->output + i * context->output_stride[0]);
15021504

1503-
context->ukernel(context->contiguous_size, input, output, NULL);
1505+
context->ukernel(context->contiguous_size, input, output, NULL);
1506+
}
15041507
}
15051508

15061509
void xnn_compute_slice_2d(struct slice_context* restrict context, size_t i,
1507-
size_t j) {
1508-
const void* input =
1509-
(const void*)((uintptr_t)context->input + i * context->input_stride[1] +
1510-
j * context->input_stride[0]);
1511-
void* output =
1512-
(void*)((uintptr_t)context->output + i * context->output_stride[1] +
1513-
j * context->output_stride[0]);
1510+
size_t offset, size_t count) {
1511+
for (size_t j = offset; j < offset + count; j++) {
1512+
const void* input =
1513+
(const void*)((uintptr_t)context->input + i * context->input_stride[1] +
1514+
j * context->input_stride[0]);
1515+
void* output =
1516+
(void*)((uintptr_t)context->output + i * context->output_stride[1] +
1517+
j * context->output_stride[0]);
15141518

1515-
context->ukernel(context->contiguous_size, input, output, NULL);
1519+
context->ukernel(context->contiguous_size, input, output, NULL);
1520+
}
15161521
}
15171522

15181523
void xnn_compute_slice_3d(struct slice_context* restrict context, size_t i,

src/operators/fully-connected-nc.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,9 +2302,9 @@ static enum xnn_status reshape_fully_connected_nc(
23022302
const size_t nc = xnn_gemm_best_tile_size(
23032303
/*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels,
23042304
/*m_stride=*/
2305-
fully_connected_op->input_pixel_stride
2306-
<< (packed_lh_config ? packed_lh_config->log2_packed_element_size
2307-
: log2_input_element_size),
2305+
input_channels << (packed_lh_config
2306+
? packed_lh_config->log2_packed_element_size
2307+
: log2_input_element_size),
23082308
/*n_stride=*/
23092309
fully_connected_op->weights_stride,
23102310
/*cn_stride=*/1 << log2_output_element_size, mr, nr,
@@ -2362,7 +2362,7 @@ static enum xnn_status reshape_fully_connected_nc(
23622362
.mr = mr_packed,
23632363
.kr = kr,
23642364
.sr = sr,
2365-
.lhs_stride = input_channels
2365+
.lhs_stride = fully_connected_op->input_pixel_stride
23662366
<< packed_lh_config->log2_input_element_size,
23672367
.packed_offset_fn = packed_lh_config->offset_fn,
23682368
.pack_lh_ukernel = packed_lh_config->pack_lh_fn,
@@ -2411,6 +2411,7 @@ static enum xnn_status reshape_fully_connected_nc(
24112411
<< log2_output_element_size,
24122412
.cn_stride = nr << log2_output_element_size,
24132413
.log2_csize = log2_output_element_size,
2414+
.log2_asize = log2_input_element_size,
24142415
.ukernel = gemm_ukernel,
24152416
.mr = mr,
24162417
.nc = output_channels,
@@ -2873,6 +2874,11 @@ static enum xnn_status setup_fully_connected_nc(
28732874
struct gemm_op_context* gemm_context =
28742875
fully_connected_op->dynamic_context.gemm;
28752876

2877+
input =
2878+
(const void*)((uintptr_t)input +
2879+
(fully_connected_op->fully_connected.input_element_offset
2880+
<< gemm_context->gemm.log2_asize));
2881+
28762882
if (fully_connected_op->num_compute_invocations == 2) {
28772883
gemm_context->pack_lh.lhs = input;
28782884
gemm_context->pack_lh.lhs_packed = workspace;

0 commit comments

Comments
 (0)