diff --git a/bench/subgraph/fully-connected.cc b/bench/subgraph/fully-connected.cc
index f8e04656280..beeb67e9408 100644
--- a/bench/subgraph/fully-connected.cc
+++ b/bench/subgraph/fully-connected.cc
@@ -202,16 +202,20 @@ BENCHMARK(QD8FullyConnected)
 
 static void FullyConnectedArgs(benchmark::internal::Benchmark* b) {
   b->ArgNames({"M", "K", "N"});
-  static const std::array kDims = {
-      1,   2,   4,    8,    16,   32,   64,    128,
-      256, 512, 1024, 2048, 4096, 8192, 16384, 65536};
-  const int64_t kMinK = 8;
+  static const std::array kDims = {
+      1,   2,    4,    8,    16,   32,    64,    128,   256,
+      512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
+  const int64_t kMinK = 1024;
+  const int64_t kMinM = 32;
   const int64_t kMaxSmall = 16;
   const int64_t kMinHuge = 1024;
   const int64_t kMinFLOPs = (int64_t)1 << 16;
   const int64_t kMaxFLOPs = (int64_t)1 << 30;
 
   for (int64_t m : kDims) {
+    if (m < kMinM) {
+      continue;
+    }
     for (int64_t k : kDims) {
       if (k < kMinK) {
         continue;
diff --git a/src/configs/BUILD b/src/configs/BUILD
index 25c5a309a0c..3d4e76a41a1 100644
--- a/src/configs/BUILD
+++ b/src/configs/BUILD
@@ -6,6 +6,7 @@
 load(
     "//:build_defs.bzl",
     "xnnpack_cc_library",
+    "xnnpack_cxx_library",
     "xnnpack_select_if",
 )
 
@@ -48,11 +49,24 @@ xnnpack_cc_library(
     ],
 )
 
+xnnpack_cxx_library(
+    name = "hardware_utils",
+    srcs = ["hardware_utils.cc"],
+    hdrs = ["hardware_utils.h"],
+    textual_hdrs = ["//:src/xnnpack/hardware-config.h"],
+    deps = [
+        "//:common",
+        "//:logging",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
+
 xnnpack_cc_library(
     name = "hardware_config",
     srcs = ["hardware-config.c"],
     hdrs = ["//:src/xnnpack/hardware-config.h"],
     deps = [
+        ":hardware_utils",
         "//:common",
         "//:init_once",
         "//:logging",
diff --git a/src/configs/hardware-config.c b/src/configs/hardware-config.c
index 74828411e91..21d576d5582 100644
--- a/src/configs/hardware-config.c
+++ b/src/configs/hardware-config.c
@@ -6,6 +6,7 @@
 #include 
 #include 
 #include 
+#include "src/configs/hardware_utils.h"
 
 #if XNN_ENABLE_CPUINFO
 #include 
@@ -462,6 +463,7 @@ static void init_hardware_config(void) {
       hardware_config.uarch[i] = xnn_uarch_unknown;
     }
 #endif  // XNN_ENABLE_CPUINFO
+  xnn_set_cache_data(&hardware_config);
 }
 
 const struct xnn_hardware_config* xnn_init_hardware_config() {
diff --git a/src/configs/hardware_utils.cc b/src/configs/hardware_utils.cc
new file mode 100644
index 00000000000..ea22a5ab3e1
--- /dev/null
+++ b/src/configs/hardware_utils.cc
@@ -0,0 +1,35 @@
+#include "src/configs/hardware_utils.h"
+
+#include "src/xnnpack/hardware-config.h"
+#include "src/xnnpack/log.h"
+#include 
+
+bool xnn_set_cache_data(struct xnn_hardware_config* hardware_config) {
+  // Get the CPUInfo.
+  const benchmark::CPUInfo& cpu_info = benchmark::CPUInfo::Get();
+
+  // Populate the `hardware_config` fields with it.
+  for (const auto& cache : cpu_info.caches) {
+    if (cache.level == 1 && (cache.type == "Data" || cache.type == "Unified")) {
+      hardware_config->l1_data_cache_bytes = cache.size;
+      xnn_log_info(
+          "l1_data_cache_bytes=%zu, l1_data_cache_line_size=%zu, "
+          "l1_data_cache_associativity=%zu, l1_data_cache_num_sets=%zu.",
+          hardware_config->l1_data_cache_bytes,
+          hardware_config->l1_data_cache_line_size,
+          hardware_config->l1_data_cache_associativity,
+          hardware_config->l1_data_cache_num_sets);
+    } else if (cache.level == 2 &&
+               (cache.type == "Data" || cache.type == "Unified")) {
+      hardware_config->l2_data_cache_bytes = cache.size;
+      xnn_log_info(
+          "l2_data_cache_bytes=%zu, l2_data_cache_line_size=%zu, "
+          "l2_data_cache_associativity=%zu, l2_data_cache_num_sets=%zu.",
+          hardware_config->l2_data_cache_bytes,
+          hardware_config->l2_data_cache_line_size,
+          hardware_config->l2_data_cache_associativity,
+          hardware_config->l2_data_cache_num_sets);
+    }
+  }
+  return true;
+}
diff --git a/src/configs/hardware_utils.h b/src/configs/hardware_utils.h
new file mode 100644
index 00000000000..b7dc2028fb4
--- /dev/null
+++ b/src/configs/hardware_utils.h
@@ -0,0 +1,18 @@
+#ifndef XNNPACK_SRC_CONFIGS_HARDWARE_UTILS_H_
+#define XNNPACK_SRC_CONFIGS_HARDWARE_UTILS_H_
+
+#include "src/xnnpack/common.h"
+#include "src/xnnpack/hardware-config.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+XNN_INTERNAL bool xnn_set_cache_data(
+    struct xnn_hardware_config* hardware_config);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // XNNPACK_SRC_CONFIGS_HARDWARE_UTILS_H_
diff --git a/src/microkernel-utils.c b/src/microkernel-utils.c
index a2b6dbf8c96..f42229e5591 100644
--- a/src/microkernel-utils.c
+++ b/src/microkernel-utils.c
@@ -155,17 +155,15 @@ size_t xnn_gemm_best_tile_size(size_t num_groups, size_t m, size_t n,
 
 // Checks whether to use the `nr2` config or not.
 bool xnn_use_nr2(size_t nr, size_t nr2, size_t output_channels) {
-  size_t nr_overcompute = (nr - output_channels % nr) % nr;
-  size_t nr2_overcompute = (nr2 - output_channels % nr2) % nr2;
-  // Switch to alternative microkernel when:
-  // 1. Alternative microkernel better supports fewer output channels, or
-  // 2. Alternative microkernel has less overcompute and default wastes >1% of
-  //    output channels
-  if (nr > output_channels || (nr2_overcompute < nr_overcompute &&
-                               nr_overcompute * 100 > output_channels)) {
-    // Default microkernel is suboptimal, use a microkernel that better
-    // supports fewer output channels.
-    return true;
+  if (nr > output_channels) {
+    size_t nr_overcompute = (nr - output_channels % nr) % nr;
+    size_t nr2_overcompute = (nr2 - output_channels % nr2) % nr2;
+    // Switch to alternative microkernel when:
+    // 1. Alternative microkernel better supports fewer output channels, or
+    // 2. Alternative microkernel has less overcompute and default wastes >1% of
+    //    output channels
+    return nr2_overcompute < nr_overcompute &&
+           nr_overcompute * 100 > output_channels;
   }
   return false;
 }
diff --git a/src/subgraph.c b/src/subgraph.c
index 7014a2611c6..6aef914f4ba 100644
--- a/src/subgraph.c
+++ b/src/subgraph.c
@@ -3621,6 +3621,57 @@ static enum xnn_status optimize_common_subgraphs_gemm_rhs_transpose(
   return xnn_status_success;
 }
 
+// Converts batch-matrix-multiply nodes with 2D weights to fully-connected nodes
+// for consistency.
+static enum xnn_status optimize_common_subgraphs_bmm_to_fc(
+    xnn_subgraph_t subgraph, uint32_t node_id, size_t* changes) {
+  struct xnn_node* node = &subgraph->nodes[node_id];
+  if (node->type != xnn_node_type_batch_matrix_multiply) {
+    return xnn_status_success;
+  }
+
+  const uint32_t input_a_id = node->inputs[0];
+  const uint32_t input_b_id = node->inputs[1];
+  const uint32_t output_id = node->outputs[0];
+  struct xnn_value* input_b_value = &subgraph->values[input_b_id];
+  const enum xnn_datatype packed_input_datatype = node->packed_input_datatype;
+
+  // Weights should have at least two dimensions, and batch dimensions
+  // should all be 1.
+  if (input_b_value->shape.num_dims != 2) {
+    return xnn_status_success;
+  }
+
+  // If the weights are dynamic, restrict to fp32/fp16.
+  if (!xnn_value_is_static(input_b_value->allocation_type) &&
+      !(input_b_value->datatype == xnn_datatype_fp32 ||
+        input_b_value->datatype == xnn_datatype_fp16)) {
+    return xnn_status_success;
+  }
+
+  // Replace with a fully-connected node.
+  XNN_RETURN_IF_ERROR(
+      xnn_define_fully_connected(
+          subgraph,
+          /*output_min=*/-INFINITY, /*output_max=*/INFINITY, input_a_id,
+          input_b_id, /*bias_id=*/XNN_INVALID_VALUE_ID, output_id,
+          node->flags ^ XNN_FLAG_TRANSPOSE_WEIGHTS),
+      "Failed to create new `fully_connected` node.");
+  node = &subgraph->nodes[node_id];
+  *node = subgraph->nodes[--subgraph->num_nodes];
+  node->id = node_id;
+  node->packed_input_datatype = packed_input_datatype;
+  subgraph->values[input_a_id].flags |= XNN_FLAG_SQUASH_GROUPS;
+
+  xnn_log_info(
+      "Converted batch_matrix_multiply[#%u](v%03u, v%03u) to "
+      "fully_connected[#%u](v%03u, v%03u).",
+      node_id, input_a_id, input_b_id, node_id, input_a_id, input_b_id);
+  (*changes)++;
+
+  return xnn_status_success;
+}
+
 static enum xnn_status optimize_common_subgraphs_iter(
     xnn_subgraph_t subgraph, uint32_t optimization_flags, size_t* changes) {
   // Loop over the nodes in this subgraph.
@@ -3739,8 +3790,14 @@ static enum xnn_status optimize_common_subgraphs_iter(
         // be pushed back to the static value.
         break;
 
-      case xnn_node_type_fully_connected:
       case xnn_node_type_batch_matrix_multiply:
+        // Convert batch-matrix-multiply nodes with 2D weights to
+        // fully-connected nodes for consistency.
+        XNN_RETURN_IF_ERROR(
+            optimize_common_subgraphs_bmm_to_fc(subgraph, node_id, changes));
+        XNN_FALLTHROUGH
+
+      case xnn_node_type_fully_connected:
        // Merge or remove transposes of the RHS of a batch-matrix-multiply or
        // fully-connected op.
        XNN_RETURN_IF_ERROR(optimize_common_subgraphs_gemm_rhs_transpose(
@@ -3907,8 +3964,8 @@ enum xnn_status xnn_subgraph_optimize_packed_lhs(xnn_subgraph_t subgraph,
             input_id, xnn_node_type_to_string(xnn_node_type_convert),
             xnn_datatype_to_string(input_datatype),
             xnn_datatype_to_string(xnn_datatype_qpint8));
-        subgraph->values[input_id].datatype = assumed_datatype;
-        subgraph->values[input_id].gemm_config = gemm_config;
+        input_value->datatype = assumed_datatype;
+        input_value->gemm_config = gemm_config;
       } else {
         // Insert a node to pack the LHS.
         xnn_log_debug(
@@ -3920,15 +3977,15 @@ enum xnn_status xnn_subgraph_optimize_packed_lhs(xnn_subgraph_t subgraph,
         uint32_t new_id = XNN_INVALID_VALUE_ID;
         XNN_RETURN_IF_ERROR(
             xnn_insert_pack_lh_node(subgraph, input_id, &new_id));
-        subgraph->nodes[node_id].inputs[0] = new_id;
+        node = &subgraph->nodes[node_id];
+        node->inputs[0] = new_id;
         changes++;
       }
       // If this is a fully-connected op, we need to coerce the shape of
       // the inputs from `[B, M, K]` to `[B * M, K]` to avoid batch-wise
       // packing.
       if (node->type == xnn_node_type_fully_connected) {
-        subgraph->values[subgraph->nodes[node_id].inputs[0]].flags |=
-            XNN_FLAG_SQUASH_GROUPS;
+        subgraph->values[node->inputs[0]].flags |= XNN_FLAG_SQUASH_GROUPS;
       }
     } else {
       if (input_datatype == xnn_datatype_qdint8) {
@@ -4178,10 +4235,6 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph,
     return xnn_status_unsupported_hardware;
   }
 
-  // Apply some common subgraph optimizations.
-  XNN_RETURN_IF_ERROR(
-      xnn_subgraph_optimize_common_subgraphs(subgraph, optimization_flags));
-
   if ((optimization_flags & XNN_FLAG_FORCE_FP16_INFERENCE) &&
       (!xnn_is_f16_compatible_config(hardware_config))) {
     xnn_log_error(
@@ -4234,6 +4287,10 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph,
   XNN_RETURN_IF_ERROR(
       xnn_subgraph_optimize_packed_lhs(subgraph, optimization_flags));
 
+  // Apply some common subgraph optimizations.
+  XNN_RETURN_IF_ERROR(
+      xnn_subgraph_optimize_common_subgraphs(subgraph, optimization_flags));
+
   return xnn_status_success;
 }
 
diff --git a/src/subgraph/fully-connected.c b/src/subgraph/fully-connected.c
index cf00e72082e..3dcba3e1aee 100644
--- a/src/subgraph/fully-connected.c
+++ b/src/subgraph/fully-connected.c
@@ -262,14 +262,17 @@ static enum xnn_status create_fully_connected_operator(
   const struct xnn_runtime_value* output_value = &values[output_id];
 
   size_t output_channels, input_channels;
+  const struct xnn_shape* filter_shape = &filter_value->shape;
   if (node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
-    input_channels = filter_value->shape.dim[0];
-    output_channels = filter_value->shape.dim[1];
+    input_channels =
+        xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
+    output_channels = filter_shape->dim[filter_shape->num_dims - 1];
   } else {
-    output_channels = filter_value->shape.dim[0];
+    output_channels =
+        xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
     // Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
     // need to look at the last dimension of the filter.
-    input_channels = filter_value->shape.dim[filter_value->shape.num_dims - 1];
+    input_channels = filter_shape->dim[filter_shape->num_dims - 1];
   }
 
   const void* kernel_data = filter_value->data;
@@ -765,18 +768,20 @@ enum xnn_status resize_fully_connected_output_tensor(
   const uint32_t input_id = opdata->inputs[0];
   const struct xnn_runtime_value* input = &values[input_id];
 
-  output->shape.num_dims = input->shape.num_dims;
-  // Infer output channels.
-  const uint32_t filter_output_channel_index =
-      (opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) ? 1 : 0;
-  output->shape.dim[output->shape.num_dims - 1] =
-      filter->shape.dim[filter_output_channel_index];
-
   // Propagate input shape to output.
+  output->shape.num_dims = input->shape.num_dims;
   for (size_t cur_dim = 0; cur_dim < input->shape.num_dims - 1; cur_dim++) {
     output->shape.dim[cur_dim] = input->shape.dim[cur_dim];
   }
 
+  // Infer output channels.
+  const size_t filter_output_channels =
+      (opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS)
+          ? filter->shape.dim[filter->shape.num_dims - 1]
+          : xnn_shape_multiply_batch_dims(&filter->shape,
+                                          /*num_nonbatch_dims=*/1);
+  output->shape.dim[output->shape.num_dims - 1] = filter_output_channels;
+
   const size_t new_size = xnn_runtime_tensor_get_size(output);
   if (new_size > output->size || old_workspace_size < opdata->workspace_size) {
     output->size = new_size;
@@ -804,21 +809,22 @@ static enum xnn_status reshape_fully_connected_operator(
   if (output_value->flags & XNN_VALUE_FLAG_LAYOUT_NCHW) {
     return reshape_convolution_operator(opdata, values, num_values, threadpool);
   }
-  const size_t num_input_elements =
-      xnn_shape_multiply_all_dims(&input_value->shape);
 
   size_t output_channels, input_channels;
+  const struct xnn_shape* filter_shape = &filter_value->shape;
   if (opdata->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
-    input_channels = filter_value->shape.dim[0];
-    output_channels = filter_value->shape.dim[1];
+    input_channels =
+        xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
+    output_channels = filter_shape->dim[filter_shape->num_dims - 1];
   } else {
-    output_channels = filter_value->shape.dim[0];
+    output_channels =
+        xnn_shape_multiply_batch_dims(filter_shape, /*num_nonbatch_dims=*/1);
     // Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
     // need to look at the last dimension of the filter.
-    input_channels = filter_value->shape.dim[filter_value->shape.num_dims - 1];
+    input_channels = filter_shape->dim[filter_shape->num_dims - 1];
   }
-  const size_t batch_size = num_input_elements / input_channels;
-  assert(batch_size * input_channels == num_input_elements);
+  const size_t batch_size = xnn_shape_multiply_batch_dims(
+      &input_value->shape, /*num_nonbatch_dims=*/1);
 
   const size_t old_workspace_size = opdata->workspace_size;
   enum xnn_status status = xnn_status_invalid_state;
@@ -1280,7 +1286,8 @@ static inline bool validate_datatypes_with_bias(
           bias_datatype == xnn_datatype_fp32 &&
           output_datatype == xnn_datatype_fp32) {
         return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp32) {
         return true;
@@ -1288,7 +1295,8 @@
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp32) {
         return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp16) {
        return true;
@@ -1299,7 +1307,8 @@
       }
       break;
     case xnn_datatype_qbint4:
-      if (input_datatype == xnn_datatype_qdint8 &&
+      if ((input_datatype == xnn_datatype_qdint8 ||
+           input_datatype == xnn_datatype_qduint8) &&
          bias_datatype == xnn_datatype_fp32 &&
          output_datatype == xnn_datatype_fp32) {
        return true;
@@ -1318,7 +1327,8 @@
          bias_datatype == xnn_datatype_fp32 &&
          output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp32) {
        return true;
@@ -1326,7 +1336,8 @@
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  bias_datatype == xnn_datatype_fp32 &&
                  output_datatype == xnn_datatype_fp16) {
        return true;
@@ -1390,13 +1401,15 @@ static inline bool validate_datatypes_without_bias(
      if (input_datatype == xnn_datatype_fp32 &&
          output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  output_datatype == xnn_datatype_fp32) {
        return true;
      } else if (input_datatype == xnn_datatype_qpint8 &&
                 output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  output_datatype == xnn_datatype_fp16) {
        return true;
      } else if (input_datatype == xnn_datatype_qint8 &&
@@ -1405,7 +1418,8 @@
      }
      break;
    case xnn_datatype_qbint4:
-      if (input_datatype == xnn_datatype_qdint8 &&
+      if ((input_datatype == xnn_datatype_qdint8 ||
+           input_datatype == xnn_datatype_qduint8) &&
          output_datatype == xnn_datatype_fp32) {
        return true;
      } else if (input_datatype == xnn_datatype_qdint8 &&
@@ -1420,13 +1434,15 @@
      if (input_datatype == xnn_datatype_fp32 &&
          output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  output_datatype == xnn_datatype_fp32) {
        return true;
      } else if (input_datatype == xnn_datatype_qpint8 &&
                 output_datatype == xnn_datatype_fp32) {
        return true;
-      } else if (input_datatype == xnn_datatype_qdint8 &&
+      } else if ((input_datatype == xnn_datatype_qdint8 ||
+                  input_datatype == xnn_datatype_qduint8) &&
                  output_datatype == xnn_datatype_fp16) {
        return true;
      } else if (input_datatype == xnn_datatype_qint8 &&
@@ -1491,6 +1507,7 @@ enum xnn_status xnn_define_fully_connected(xnn_subgraph_t subgraph,
    case xnn_datatype_qpint8:
      break;
    case xnn_datatype_qdint8:
+    case xnn_datatype_qduint8:
      if (input_value->quantization.num_nonbatch_dims >
          input_value->shape.num_dims) {
        xnn_log_error("failed to define %s operator with input ID #%" PRIu32
diff --git a/test/replicable_random_device.h b/test/replicable_random_device.h
index 966c4b31a7a..778c9d4135f 100644
--- a/test/replicable_random_device.h
+++ b/test/replicable_random_device.h
@@ -94,7 +94,7 @@ class ReplicableRandomDevice {
     static bool is_set = false;
     static int last_seed = 0;
     if (!is_set || last_seed != random_seed_) {
-      std::cout
+      std::cerr
          << "Creating a random device for testing, to replicate it re-run the "
             "test with `--gtest_random_seed=" << random_seed_ << "`."
          << std::endl;
diff --git a/test/subgraph/rewrites.cc b/test/subgraph/rewrites.cc
index 3b549c9c119..b35de2dc1a6 100644
--- a/test/subgraph/rewrites.cc
+++ b/test/subgraph/rewrites.cc
@@ -23,6 +23,7 @@
 #include "include/xnnpack.h"
 #include "src/subgraph/subgraph-utils.h"
 #include "src/xnnpack/buffer.h"
+#include "src/xnnpack/common.h"
 #include "src/xnnpack/datatype.h"
 #include "src/xnnpack/node-type.h"
 #include "src/xnnpack/subgraph.h"
@@ -1875,9 +1876,17 @@ TEST_P(RewriteGemmTest, RewritesGoiToGioAndElidesSpuriousTranspose) {
       /*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 0}},
       /*test_fn=*/
       [](xnn_subgraph_t subgraph) {
-        const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
-        ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
-        ASSERT_EQ(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+        const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
+        switch (node->type) {
+          case xnn_node_type_batch_matrix_multiply:
+            ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          case xnn_node_type_fully_connected:
+            ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          default:
+            XNN_UNREACHABLE;
+        }
       });
 }
 
@@ -1946,9 +1955,17 @@ TEST_P(RewriteGemmTest, RewritesGioToGoiAndKeepsNonSpuriousTranspose) {
       /*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 1}},
       /*test_fn=*/
       [](xnn_subgraph_t subgraph) {
-        const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
-        ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
-        ASSERT_NE(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+        const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
+        switch (node->type) {
+          case xnn_node_type_batch_matrix_multiply:
+            ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          case xnn_node_type_fully_connected:
+            ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          default:
+            XNN_UNREACHABLE;
+        }
       });
 }
 
@@ -2003,9 +2020,17 @@ TEST_P(RewriteGemmTest, DoesNotRewritesGoiToGioWithNonSpuriousTranspose) {
       /*expected_node_type_counts=*/{{xnn_node_type_static_transpose, 1}},
       /*test_fn=*/
       [](xnn_subgraph_t subgraph) {
-        const xnn_node* bmm_node = &subgraph->nodes[subgraph->num_nodes - 1];
-        ASSERT_EQ(bmm_node->type, xnn_node_type_batch_matrix_multiply);
-        ASSERT_NE(bmm_node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+        const xnn_node* node = &subgraph->nodes[subgraph->num_nodes - 1];
+        switch (node->type) {
+          case xnn_node_type_batch_matrix_multiply:
+            ASSERT_NE(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          case xnn_node_type_fully_connected:
+            ASSERT_EQ(node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS, 0);
+            break;
+          default:
+            XNN_UNREACHABLE;
+        }
       });
 }
 
diff --git a/test/subgraph/subgraph-fp16.cc b/test/subgraph/subgraph-fp16.cc
index 14ebdccd700..838143cce8d 100644
--- a/test/subgraph/subgraph-fp16.cc
+++ b/test/subgraph/subgraph-fp16.cc
@@ -1135,13 +1135,17 @@ TEST(SUBGRAPH_FP16_BATCH_MATRIX_MULTIPLY, with_static_value) {
   switch (tester.NumNodes()) {
     case 3:
       ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
-      ASSERT_EQ(tester.Node(1)->type, xnn_node_type_batch_matrix_multiply);
+      ASSERT_THAT(tester.Node(1)->type,
+                  testing::AnyOf(xnn_node_type_batch_matrix_multiply,
+                                 xnn_node_type_fully_connected));
       ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
       break;
     case 4:
       ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
       ASSERT_EQ(tester.Node(1)->type, xnn_node_type_pack_lh);
-      ASSERT_EQ(tester.Node(2)->type, xnn_node_type_batch_matrix_multiply);
+      ASSERT_THAT(tester.Node(2)->type,
+                  testing::AnyOf(xnn_node_type_batch_matrix_multiply,
+                                 xnn_node_type_fully_connected));
       ASSERT_EQ(tester.Node(3)->type, xnn_node_type_convert);
       break;
     default:
@@ -1204,14 +1208,18 @@ TEST(SUBGRAPH_FP16_BATCH_MATRIX_MULTIPLY, with_non_static_value) {
     case 4:
       ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
       ASSERT_EQ(tester.Node(1)->type, xnn_node_type_convert);
-      ASSERT_EQ(tester.Node(2)->type, xnn_node_type_batch_matrix_multiply);
+      ASSERT_THAT(tester.Node(2)->type,
+                  testing::AnyOf(xnn_node_type_batch_matrix_multiply,
+                                 xnn_node_type_fully_connected));
       ASSERT_EQ(tester.Node(3)->type, xnn_node_type_convert);
       break;
     case 5:
       ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
       ASSERT_EQ(tester.Node(1)->type, xnn_node_type_pack_lh);
       ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
-      ASSERT_EQ(tester.Node(3)->type, xnn_node_type_batch_matrix_multiply);
+      ASSERT_THAT(tester.Node(3)->type,
+                  testing::AnyOf(xnn_node_type_batch_matrix_multiply,
+                                 xnn_node_type_fully_connected));
       ASSERT_EQ(tester.Node(4)->type, xnn_node_type_convert);
       break;
     default: