Skip to content

Commit 736076e

Browse files
gonnet authored and xnnpack-bot committed
Convert batch-matrix-multiply nodes with 2D weights to fully-connected nodes.
PiperOrigin-RevId: 829464154
1 parent 6eba261 commit 736076e

File tree

4 files changed

+153
-47
lines changed

4 files changed

+153
-47
lines changed

src/subgraph.c

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3621,6 +3621,56 @@ static enum xnn_status optimize_common_subgraphs_gemm_rhs_transpose(
36213621
return xnn_status_success;
36223622
}
36233623

3624+
// Converts batch-matrix-multiply nodes with 2D weights to fully-connected nodes
3625+
// for consistency.
3626+
static enum xnn_status optimize_common_subgraphs_bmm_to_fc(
3627+
xnn_subgraph_t subgraph, uint32_t node_id, size_t* changes) {
3628+
struct xnn_node* node = &subgraph->nodes[node_id];
3629+
if (node->type != xnn_node_type_batch_matrix_multiply) {
3630+
return xnn_status_success;
3631+
}
3632+
3633+
const uint32_t input_a_id = node->inputs[0];
3634+
const uint32_t input_b_id = node->inputs[1];
3635+
const uint32_t output_id = node->outputs[0];
3636+
struct xnn_value* input_b_value = &subgraph->values[input_b_id];
3637+
const enum xnn_datatype packed_input_datatype = node->packed_input_datatype;
3638+
3639+
// Weights should have at least two dimensions, and batch dimensions
3640+
// should all be 1.
3641+
if (input_b_value->shape.num_dims != 2) {
3642+
return xnn_status_success;
3643+
}
3644+
3645+
// If the weights are dynamic, restrict to fp32/fp16.
3646+
if (!xnn_value_is_static(input_b_value->allocation_type) &&
3647+
!(input_b_value->datatype == xnn_datatype_fp32 ||
3648+
input_b_value->datatype == xnn_datatype_fp16)) {
3649+
return xnn_status_success;
3650+
}
3651+
3652+
// Replace with a fully-connected node.
3653+
XNN_RETURN_IF_ERROR(
3654+
xnn_define_fully_connected(
3655+
subgraph,
3656+
/*output_min=*/-INFINITY, /*output_max=*/INFINITY, input_a_id,
3657+
input_b_id, /*bias_id=*/XNN_INVALID_VALUE_ID, output_id,
3658+
node->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS),
3659+
"Failed to create new `fully_connected` node.");
3660+
node = &subgraph->nodes[node_id];
3661+
*node = subgraph->nodes[--subgraph->num_nodes];
3662+
node->id = node_id;
3663+
node->packed_input_datatype = packed_input_datatype;
3664+
3665+
xnn_log_info(
3666+
"Converted batch_matrix_multiply[#%u](v%03u, v%03u) to "
3667+
"fully_connected[#%u](v%03u, v%03u).",
3668+
node_id, input_a_id, input_b_id, node_id, input_a_id, input_b_id);
3669+
(*changes)++;
3670+
3671+
return xnn_status_success;
3672+
}
3673+
36243674
static enum xnn_status optimize_common_subgraphs_iter(
36253675
xnn_subgraph_t subgraph, uint32_t optimization_flags, size_t* changes) {
36263676
// Loop over the nodes in this subgraph.
@@ -3739,8 +3789,14 @@ static enum xnn_status optimize_common_subgraphs_iter(
37393789
// be pushed back to the static value.
37403790
break;
37413791

3742-
case xnn_node_type_fully_connected:
37433792
case xnn_node_type_batch_matrix_multiply:
3793+
// Convert batch-matrix-multiply nodes with 2D weights to
3794+
// fully-connected nodes for consistency.
3795+
XNN_RETURN_IF_ERROR(
3796+
optimize_common_subgraphs_bmm_to_fc(subgraph, node_id, changes));
3797+
XNN_FALLTHROUGH
3798+
3799+
case xnn_node_type_fully_connected:
37443800
// Merge or remove transposes of the RHS of a batch-matrix-multiply or
37453801
// fully-connected op.
37463802
XNN_RETURN_IF_ERROR(optimize_common_subgraphs_gemm_rhs_transpose(
@@ -4178,10 +4234,6 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph,
41784234
return xnn_status_unsupported_hardware;
41794235
}
41804236

4181-
// Apply some common subgraph optimizations.
4182-
XNN_RETURN_IF_ERROR(
4183-
xnn_subgraph_optimize_common_subgraphs(subgraph, optimization_flags));
4184-
41854237
if ((optimization_flags & XNN_FLAG_FORCE_FP16_INFERENCE) &&
41864238
(!xnn_is_f16_compatible_config(hardware_config))) {
41874239
xnn_log_error(
@@ -4234,6 +4286,10 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph,
42344286
XNN_RETURN_IF_ERROR(
42354287
xnn_subgraph_optimize_packed_lhs(subgraph, optimization_flags));
42364288

4289+
// Apply some common subgraph optimizations.
4290+
XNN_RETURN_IF_ERROR(
4291+
xnn_subgraph_optimize_common_subgraphs(subgraph, optimization_flags));
4292+
42374293
return xnn_status_success;
42384294
}
42394295

src/subgraph/fully-connected.c

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,17 @@ static enum xnn_status create_fully_connected_operator(
262262
const struct xnn_runtime_value* output_value = &values[output_id];
263263

264264
size_t output_channels, input_channels;
265+
const struct xnn_shape* filter_shape = &filter_value->shape;
265266
if (node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
266-
input_channels = filter_value->shape.dim[0];
267-
output_channels = filter_value->shape.dim[1];
267+
input_channels =
268+
xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
269+
output_channels = filter_shape->dim[filter_shape->num_dims - 1];
268270
} else {
269-
output_channels = filter_value->shape.dim[0];
271+
output_channels =
272+
xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
270273
// Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
271274
// need to look at the last dimension of the filter.
272-
input_channels = filter_value->shape.dim[filter_value->shape.num_dims - 1];
275+
input_channels = filter_shape->dim[filter_shape->num_dims - 1];
273276
}
274277

275278
const void* kernel_data = filter_value->data;
@@ -765,18 +768,20 @@ enum xnn_status resize_fully_connected_output_tensor(
765768
const uint32_t input_id = opdata->inputs[0];
766769
const struct xnn_runtime_value* input = &values[input_id];
767770

768-
output->shape.num_dims = input->shape.num_dims;
769-
// Infer output channels.
770-
const uint32_t filter_output_channel_index =
771-
(opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) ? 1 : 0;
772-
output->shape.dim[output->shape.num_dims - 1] =
773-
filter->shape.dim[filter_output_channel_index];
774-
775771
// Propagate input shape to output.
772+
output->shape.num_dims = input->shape.num_dims;
776773
for (size_t cur_dim = 0; cur_dim < input->shape.num_dims - 1; cur_dim++) {
777774
output->shape.dim[cur_dim] = input->shape.dim[cur_dim];
778775
}
779776

777+
// Infer output channels.
778+
const size_t filter_output_channels =
779+
(opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS)
780+
? filter->shape.dim[filter->shape.num_dims - 1]
781+
: xnn_shape_multiply_batch_dims(&filter->shape,
782+
/*num_nonbatch_dims=*/1);
783+
output->shape.dim[output->shape.num_dims - 1] = filter_output_channels;
784+
780785
const size_t new_size = xnn_runtime_tensor_get_size(output);
781786
if (new_size > output->size || old_workspace_size < opdata->workspace_size) {
782787
output->size = new_size;
@@ -804,21 +809,22 @@ static enum xnn_status reshape_fully_connected_operator(
804809
if (output_value->flags & XNN_VALUE_FLAG_LAYOUT_NCHW) {
805810
return reshape_convolution_operator(opdata, values, num_values, threadpool);
806811
}
807-
const size_t num_input_elements =
808-
xnn_shape_multiply_all_dims(&input_value->shape);
809812
size_t output_channels, input_channels;
813+
const struct xnn_shape* filter_shape = &filter_value->shape;
810814
if (opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
811-
input_channels = filter_value->shape.dim[0];
812-
output_channels = filter_value->shape.dim[1];
815+
input_channels =
816+
xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
817+
output_channels = filter_shape->dim[filter_shape->num_dims - 1];
813818
} else {
814-
output_channels = filter_value->shape.dim[0];
819+
output_channels =
820+
xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
815821
// Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
816822
// need to look at the last dimension of the filter.
817-
input_channels = filter_value->shape.dim[filter_value->shape.num_dims - 1];
823+
input_channels = filter_shape->dim[filter_shape->num_dims - 1];
818824
}
819825

820-
const size_t batch_size = num_input_elements / input_channels;
821-
assert(batch_size * input_channels == num_input_elements);
826+
const size_t batch_size = xnn_shape_multiply_batch_dims(
827+
&input_value->shape, /*num_nonbatch_dims=*/1);
822828
const size_t old_workspace_size = opdata->workspace_size;
823829
enum xnn_status status = xnn_status_invalid_state;
824830

@@ -1280,15 +1286,17 @@ static inline bool validate_datatypes_with_bias(
12801286
bias_datatype == xnn_datatype_fp32 &&
12811287
output_datatype == xnn_datatype_fp32) {
12821288
return true;
1283-
} else if (input_datatype == xnn_datatype_qdint8 &&
1289+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1290+
input_datatype == xnn_datatype_qduint8) &&
12841291
bias_datatype == xnn_datatype_fp32 &&
12851292
output_datatype == xnn_datatype_fp32) {
12861293
return true;
12871294
} else if (input_datatype == xnn_datatype_qpint8 &&
12881295
bias_datatype == xnn_datatype_fp32 &&
12891296
output_datatype == xnn_datatype_fp32) {
12901297
return true;
1291-
} else if (input_datatype == xnn_datatype_qdint8 &&
1298+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1299+
input_datatype == xnn_datatype_qduint8) &&
12921300
bias_datatype == xnn_datatype_fp32 &&
12931301
output_datatype == xnn_datatype_fp16) {
12941302
return true;
@@ -1299,7 +1307,8 @@ static inline bool validate_datatypes_with_bias(
12991307
}
13001308
break;
13011309
case xnn_datatype_qbint4:
1302-
if (input_datatype == xnn_datatype_qdint8 &&
1310+
if ((input_datatype == xnn_datatype_qdint8 ||
1311+
input_datatype == xnn_datatype_qduint8) &&
13031312
bias_datatype == xnn_datatype_fp32 &&
13041313
output_datatype == xnn_datatype_fp32) {
13051314
return true;
@@ -1318,15 +1327,17 @@ static inline bool validate_datatypes_with_bias(
13181327
bias_datatype == xnn_datatype_fp32 &&
13191328
output_datatype == xnn_datatype_fp32) {
13201329
return true;
1321-
} else if (input_datatype == xnn_datatype_qdint8 &&
1330+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1331+
input_datatype == xnn_datatype_qduint8) &&
13221332
bias_datatype == xnn_datatype_fp32 &&
13231333
output_datatype == xnn_datatype_fp32) {
13241334
return true;
13251335
} else if (input_datatype == xnn_datatype_qpint8 &&
13261336
bias_datatype == xnn_datatype_fp32 &&
13271337
output_datatype == xnn_datatype_fp32) {
13281338
return true;
1329-
} else if (input_datatype == xnn_datatype_qdint8 &&
1339+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1340+
input_datatype == xnn_datatype_qduint8) &&
13301341
bias_datatype == xnn_datatype_fp32 &&
13311342
output_datatype == xnn_datatype_fp16) {
13321343
return true;
@@ -1390,13 +1401,15 @@ static inline bool validate_datatypes_without_bias(
13901401
if (input_datatype == xnn_datatype_fp32 &&
13911402
output_datatype == xnn_datatype_fp32) {
13921403
return true;
1393-
} else if (input_datatype == xnn_datatype_qdint8 &&
1404+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1405+
input_datatype == xnn_datatype_qduint8) &&
13941406
output_datatype == xnn_datatype_fp32) {
13951407
return true;
13961408
} else if (input_datatype == xnn_datatype_qpint8 &&
13971409
output_datatype == xnn_datatype_fp32) {
13981410
return true;
1399-
} else if (input_datatype == xnn_datatype_qdint8 &&
1411+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1412+
input_datatype == xnn_datatype_qduint8) &&
14001413
output_datatype == xnn_datatype_fp16) {
14011414
return true;
14021415
} else if (input_datatype == xnn_datatype_qint8 &&
@@ -1405,7 +1418,8 @@ static inline bool validate_datatypes_without_bias(
14051418
}
14061419
break;
14071420
case xnn_datatype_qbint4:
1408-
if (input_datatype == xnn_datatype_qdint8 &&
1421+
if ((input_datatype == xnn_datatype_qdint8 ||
1422+
input_datatype == xnn_datatype_qduint8) &&
14091423
output_datatype == xnn_datatype_fp32) {
14101424
return true;
14111425
} else if (input_datatype == xnn_datatype_qdint8 &&
@@ -1420,13 +1434,15 @@ static inline bool validate_datatypes_without_bias(
14201434
if (input_datatype == xnn_datatype_fp32 &&
14211435
output_datatype == xnn_datatype_fp32) {
14221436
return true;
1423-
} else if (input_datatype == xnn_datatype_qdint8 &&
1437+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1438+
input_datatype == xnn_datatype_qduint8) &&
14241439
output_datatype == xnn_datatype_fp32) {
14251440
return true;
14261441
} else if (input_datatype == xnn_datatype_qpint8 &&
14271442
output_datatype == xnn_datatype_fp32) {
14281443
return true;
1429-
} else if (input_datatype == xnn_datatype_qdint8 &&
1444+
} else if ((input_datatype == xnn_datatype_qdint8 ||
1445+
input_datatype == xnn_datatype_qduint8) &&
14301446
output_datatype == xnn_datatype_fp16) {
14311447
return true;
14321448
} else if (input_datatype == xnn_datatype_qint8 &&
@@ -1491,6 +1507,7 @@ enum xnn_status xnn_define_fully_connected(xnn_subgraph_t subgraph,
14911507
case xnn_datatype_qpint8:
14921508
break;
14931509
case xnn_datatype_qdint8:
1510+
case xnn_datatype_qduint8:
14941511
if (input_value->quantization.num_nonbatch_dims >
14951512
input_value->shape.num_dims) {
14961513
xnn_log_error("failed to define %s operator with input ID #%" PRIu32

test/subgraph/rewrites.cc

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "include/xnnpack.h"
2424
#include "src/subgraph/subgraph-utils.h"
2525
#include "src/xnnpack/buffer.h"
26+
#include "src/xnnpack/common.h"
2627
#include "src/xnnpack/datatype.h"
2728
#include "src/xnnpack/node-type.h"
2829
#include "src/xnnpack/subgraph.h"
@@ -1875,9 +1876,17 @@ TEST_P(RewriteGemmTest, RewritesGoiToGioAndElidesSpuriousTranspose) {
18751876
/*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 0}},
18761877
/*test_fn=*/
18771878
[](xnn_subgraph_t subgraph) {
1878-
const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
1879-
ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
1880-
ASSERT_EQ(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1879+
const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
1880+
switch (node->type) {
1881+
case xnn_node_type_batch_matrix_multiply:
1882+
ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1883+
break;
1884+
case xnn_node_type_fully_connected:
1885+
ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1886+
break;
1887+
default:
1888+
XNN_UNREACHABLE;
1889+
}
18811890
});
18821891
}
18831892

@@ -1946,9 +1955,17 @@ TEST_P(RewriteGemmTest, RewritesGioToGoiAndKeepsNonSpuriousTranspose) {
19461955
/*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 1}},
19471956
/*test_fn=*/
19481957
[](xnn_subgraph_t subgraph) {
1949-
const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
1950-
ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
1951-
ASSERT_NE(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1958+
const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
1959+
switch (node->type) {
1960+
case xnn_node_type_batch_matrix_multiply:
1961+
ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1962+
break;
1963+
case xnn_node_type_fully_connected:
1964+
ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
1965+
break;
1966+
default:
1967+
XNN_UNREACHABLE;
1968+
}
19521969
});
19531970
}
19541971

@@ -2003,9 +2020,17 @@ TEST_P(RewriteGemmTest, DoesNotRewritesGoiToGioWithNonSpuriousTranspose) {
20032020
/*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 1}},
20042021
/*test_fn=*/
20052022
[](xnn_subgraph_t subgraph) {
2006-
const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
2007-
ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
2008-
ASSERT_NE(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
2023+
const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
2024+
switch (node->type) {
2025+
case xnn_node_type_batch_matrix_multiply:
2026+
ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
2027+
break;
2028+
case xnn_node_type_fully_connected:
2029+
ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
2030+
break;
2031+
default:
2032+
XNN_UNREACHABLE;
2033+
}
20092034
});
20102035
}
20112036

test/subgraph/subgraph-fp16.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,13 +1135,17 @@ TEST(SUBGRAPH_FP16_BATCH_MATRIX_MULTIPLY, with_static_value) {
11351135
switch (tester.NumNodes()) {
11361136
case 3:
11371137
ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
1138-
ASSERT_EQ(tester.Node(1)->type, xnn_node_type_batch_matrix_multiply);
1138+
ASSERT_THAT(tester.Node(1)->type,
1139+
testing::AnyOf(xnn_node_type_batch_matrix_multiply,
1140+
xnn_node_type_fully_connected));
11391141
ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
11401142
break;
11411143
case 4:
11421144
ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
11431145
ASSERT_EQ(tester.Node(1)->type, xnn_node_type_pack_lh);
1144-
ASSERT_EQ(tester.Node(2)->type, xnn_node_type_batch_matrix_multiply);
1146+
ASSERT_THAT(tester.Node(2)->type,
1147+
testing::AnyOf(xnn_node_type_batch_matrix_multiply,
1148+
xnn_node_type_fully_connected));
11451149
ASSERT_EQ(tester.Node(3)->type, xnn_node_type_convert);
11461150
break;
11471151
default:
@@ -1204,14 +1208,18 @@ TEST(SUBGRAPH_FP16_BATCH_MATRIX_MULTIPLY, with_non_static_value) {
12041208
case 4:
12051209
ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
12061210
ASSERT_EQ(tester.Node(1)->type, xnn_node_type_convert);
1207-
ASSERT_EQ(tester.Node(2)->type, xnn_node_type_batch_matrix_multiply);
1211+
ASSERT_THAT(tester.Node(2)->type,
1212+
testing::AnyOf(xnn_node_type_batch_matrix_multiply,
1213+
xnn_node_type_fully_connected));
12081214
ASSERT_EQ(tester.Node(3)->type, xnn_node_type_convert);
12091215
break;
12101216
case 5:
12111217
ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
12121218
ASSERT_EQ(tester.Node(1)->type, xnn_node_type_pack_lh);
12131219
ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
1214-
ASSERT_EQ(tester.Node(3)->type, xnn_node_type_batch_matrix_multiply);
1220+
ASSERT_THAT(tester.Node(3)->type,
1221+
testing::AnyOf(xnn_node_type_batch_matrix_multiply,
1222+
xnn_node_type_fully_connected));
12151223
ASSERT_EQ(tester.Node(4)->type, xnn_node_type_convert);
12161224
break;
12171225
default:

0 commit comments

Comments (0)