@@ -262,14 +262,17 @@ static enum xnn_status create_fully_connected_operator(
262262 const struct xnn_runtime_value * output_value = & values [output_id ];
263263
264264 size_t output_channels , input_channels ;
265+ const struct xnn_shape * filter_shape = & filter_value -> shape ;
265266 if (node -> flags & XNN_FLAG_TRANSPOSE_WEIGHTS ) {
266- input_channels = filter_value -> shape .dim [0 ];
267- output_channels = filter_value -> shape .dim [1 ];
267+ input_channels =
268+ xnn_shape_multiply_batch_dims (filter_shape , /*num_nonbatch_dims=*/ 1 );
269+ output_channels = filter_shape -> dim [filter_shape -> num_dims - 1 ];
268270 } else {
269- output_channels = filter_value -> shape .dim [0 ];
271+ output_channels =
272+ xnn_shape_multiply_batch_dims (filter_shape , /*num_nonbatch_dims=*/ 1 );
270273 // Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
271274 // need to look at the last dimension of the filter.
272- input_channels = filter_value -> shape . dim [filter_value -> shape . num_dims - 1 ];
275+ input_channels = filter_shape -> dim [filter_shape -> num_dims - 1 ];
273276 }
274277
275278 const void * kernel_data = filter_value -> data ;
@@ -765,18 +768,20 @@ enum xnn_status resize_fully_connected_output_tensor(
765768 const uint32_t input_id = opdata -> inputs [0 ];
766769 const struct xnn_runtime_value * input = & values [input_id ];
767770
768- output -> shape .num_dims = input -> shape .num_dims ;
769- // Infer output channels.
770- const uint32_t filter_output_channel_index =
771- (opdata -> flags & XNN_FLAG_TRANSPOSE_WEIGHTS ) ? 1 : 0 ;
772- output -> shape .dim [output -> shape .num_dims - 1 ] =
773- filter -> shape .dim [filter_output_channel_index ];
774-
775771 // Propagate input shape to output.
772+ output -> shape .num_dims = input -> shape .num_dims ;
776773 for (size_t cur_dim = 0 ; cur_dim < input -> shape .num_dims - 1 ; cur_dim ++ ) {
777774 output -> shape .dim [cur_dim ] = input -> shape .dim [cur_dim ];
778775 }
779776
777+ // Infer output channels.
778+ const size_t filter_output_channels =
779+ (opdata -> flags & XNN_FLAG_TRANSPOSE_WEIGHTS )
780+ ? filter -> shape .dim [filter -> shape .num_dims - 1 ]
781+ : xnn_shape_multiply_batch_dims (& filter -> shape ,
782+ /*num_nonbatch_dims=*/ 1 );
783+ output -> shape .dim [output -> shape .num_dims - 1 ] = filter_output_channels ;
784+
780785 const size_t new_size = xnn_runtime_tensor_get_size (output );
781786 if (new_size > output -> size || old_workspace_size < opdata -> workspace_size ) {
782787 output -> size = new_size ;
@@ -804,21 +809,22 @@ static enum xnn_status reshape_fully_connected_operator(
804809 if (output_value -> flags & XNN_VALUE_FLAG_LAYOUT_NCHW ) {
805810 return reshape_convolution_operator (opdata , values , num_values , threadpool );
806811 }
807- const size_t num_input_elements =
808- xnn_shape_multiply_all_dims (& input_value -> shape );
809812 size_t output_channels , input_channels ;
813+ const struct xnn_shape * filter_shape = & filter_value -> shape ;
810814 if (opdata -> flags & XNN_FLAG_TRANSPOSE_WEIGHTS ) {
811- input_channels = filter_value -> shape .dim [0 ];
812- output_channels = filter_value -> shape .dim [1 ];
815+ input_channels =
816+ xnn_shape_multiply_batch_dims (filter_shape , /*num_nonbatch_dims=*/ 1 );
817+ output_channels = filter_shape -> dim [filter_shape -> num_dims - 1 ];
813818 } else {
814- output_channels = filter_value -> shape .dim [0 ];
819+ output_channels =
820+ xnn_shape_multiply_batch_dims (filter_shape , /*num_nonbatch_dims=*/ 1 );
815821 // Note that for convolutions, the filter shape can be `[H, 1, 1, W]`, so we
816822 // need to look at the last dimension of the filter.
817- input_channels = filter_value -> shape . dim [filter_value -> shape . num_dims - 1 ];
823+ input_channels = filter_shape -> dim [filter_shape -> num_dims - 1 ];
818824 }
819825
820- const size_t batch_size = num_input_elements / input_channels ;
821- assert ( batch_size * input_channels == num_input_elements );
826+ const size_t batch_size = xnn_shape_multiply_batch_dims (
827+ & input_value -> shape , /*num_nonbatch_dims=*/ 1 );
822828 const size_t old_workspace_size = opdata -> workspace_size ;
823829 enum xnn_status status = xnn_status_invalid_state ;
824830
@@ -1280,15 +1286,17 @@ static inline bool validate_datatypes_with_bias(
12801286 bias_datatype == xnn_datatype_fp32 &&
12811287 output_datatype == xnn_datatype_fp32 ) {
12821288 return true;
1283- } else if (input_datatype == xnn_datatype_qdint8 &&
1289+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1290+ input_datatype == xnn_datatype_qduint8 ) &&
12841291 bias_datatype == xnn_datatype_fp32 &&
12851292 output_datatype == xnn_datatype_fp32 ) {
12861293 return true;
12871294 } else if (input_datatype == xnn_datatype_qpint8 &&
12881295 bias_datatype == xnn_datatype_fp32 &&
12891296 output_datatype == xnn_datatype_fp32 ) {
12901297 return true;
1291- } else if (input_datatype == xnn_datatype_qdint8 &&
1298+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1299+ input_datatype == xnn_datatype_qduint8 ) &&
12921300 bias_datatype == xnn_datatype_fp32 &&
12931301 output_datatype == xnn_datatype_fp16 ) {
12941302 return true;
@@ -1299,7 +1307,8 @@ static inline bool validate_datatypes_with_bias(
12991307 }
13001308 break ;
13011309 case xnn_datatype_qbint4 :
1302- if (input_datatype == xnn_datatype_qdint8 &&
1310+ if ((input_datatype == xnn_datatype_qdint8 ||
1311+ input_datatype == xnn_datatype_qduint8 ) &&
13031312 bias_datatype == xnn_datatype_fp32 &&
13041313 output_datatype == xnn_datatype_fp32 ) {
13051314 return true;
@@ -1318,15 +1327,17 @@ static inline bool validate_datatypes_with_bias(
13181327 bias_datatype == xnn_datatype_fp32 &&
13191328 output_datatype == xnn_datatype_fp32 ) {
13201329 return true;
1321- } else if (input_datatype == xnn_datatype_qdint8 &&
1330+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1331+ input_datatype == xnn_datatype_qduint8 ) &&
13221332 bias_datatype == xnn_datatype_fp32 &&
13231333 output_datatype == xnn_datatype_fp32 ) {
13241334 return true;
13251335 } else if (input_datatype == xnn_datatype_qpint8 &&
13261336 bias_datatype == xnn_datatype_fp32 &&
13271337 output_datatype == xnn_datatype_fp32 ) {
13281338 return true;
1329- } else if (input_datatype == xnn_datatype_qdint8 &&
1339+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1340+ input_datatype == xnn_datatype_qduint8 ) &&
13301341 bias_datatype == xnn_datatype_fp32 &&
13311342 output_datatype == xnn_datatype_fp16 ) {
13321343 return true;
@@ -1390,13 +1401,15 @@ static inline bool validate_datatypes_without_bias(
13901401 if (input_datatype == xnn_datatype_fp32 &&
13911402 output_datatype == xnn_datatype_fp32 ) {
13921403 return true;
1393- } else if (input_datatype == xnn_datatype_qdint8 &&
1404+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1405+ input_datatype == xnn_datatype_qduint8 ) &&
13941406 output_datatype == xnn_datatype_fp32 ) {
13951407 return true;
13961408 } else if (input_datatype == xnn_datatype_qpint8 &&
13971409 output_datatype == xnn_datatype_fp32 ) {
13981410 return true;
1399- } else if (input_datatype == xnn_datatype_qdint8 &&
1411+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1412+ input_datatype == xnn_datatype_qduint8 ) &&
14001413 output_datatype == xnn_datatype_fp16 ) {
14011414 return true;
14021415 } else if (input_datatype == xnn_datatype_qint8 &&
@@ -1405,7 +1418,8 @@ static inline bool validate_datatypes_without_bias(
14051418 }
14061419 break ;
14071420 case xnn_datatype_qbint4 :
1408- if (input_datatype == xnn_datatype_qdint8 &&
1421+ if ((input_datatype == xnn_datatype_qdint8 ||
1422+ input_datatype == xnn_datatype_qduint8 ) &&
14091423 output_datatype == xnn_datatype_fp32 ) {
14101424 return true;
14111425 } else if (input_datatype == xnn_datatype_qdint8 &&
@@ -1420,13 +1434,15 @@ static inline bool validate_datatypes_without_bias(
14201434 if (input_datatype == xnn_datatype_fp32 &&
14211435 output_datatype == xnn_datatype_fp32 ) {
14221436 return true;
1423- } else if (input_datatype == xnn_datatype_qdint8 &&
1437+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1438+ input_datatype == xnn_datatype_qduint8 ) &&
14241439 output_datatype == xnn_datatype_fp32 ) {
14251440 return true;
14261441 } else if (input_datatype == xnn_datatype_qpint8 &&
14271442 output_datatype == xnn_datatype_fp32 ) {
14281443 return true;
1429- } else if (input_datatype == xnn_datatype_qdint8 &&
1444+ } else if ((input_datatype == xnn_datatype_qdint8 ||
1445+ input_datatype == xnn_datatype_qduint8 ) &&
14301446 output_datatype == xnn_datatype_fp16 ) {
14311447 return true;
14321448 } else if (input_datatype == xnn_datatype_qint8 &&
@@ -1491,6 +1507,7 @@ enum xnn_status xnn_define_fully_connected(xnn_subgraph_t subgraph,
14911507 case xnn_datatype_qpint8 :
14921508 break ;
14931509 case xnn_datatype_qdint8 :
1510+ case xnn_datatype_qduint8 :
14941511 if (input_value -> quantization .num_nonbatch_dims >
14951512 input_value -> shape .num_dims ) {
14961513 xnn_log_error ("failed to define %s operator with input ID #%" PRIu32
0 commit comments