// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
13*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
15*4bdc9457SAndroid Build Coastguard Worker
16*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
25*4bdc9457SAndroid Build Coastguard Worker
26*4bdc9457SAndroid Build Coastguard Worker
// Shared constructor for the fully-connected (NC layout) operator variants.
//
// Validates the arguments, allocates a zero-initialized operator descriptor,
// packs `kernel` and `bias` into the GEMM microkernel layout described by
// `gemm_parameters` (optionally deduplicated through the weights cache in
// `caches`), and records the microkernels and datatype-specific `params`.
//
// Parameters:
//   input_channels, output_channels: GEMM K and N dimensions; must be non-zero.
//   input_stride, output_stride: row strides in elements; must be at least
//     input_channels / output_channels respectively.
//   kernel, bias: weights handed to the packing function. With
//     XNN_FLAG_TRANSPOSE_WEIGHTS set in `flags`, they are packed by
//     `pack_gemm_io_w`; otherwise by `pack_gemm_goi_w` with a group count of 1.
//   log2_filter_element_size, bias_element_size: sizes of the *packed* filter
//     and bias elements, used to size the packed-weights buffer.
//   packing_params: opaque argument forwarded to the packing function.
//   packed_weights_padding_byte: forwarded to xnn_get_pointer_to_write_weights.
//   params, params_size: datatype-specific parameters copied into the operator.
//   gemm_parameters: microkernel tile geometry (mr, nr, log2_kr, log2_sr).
//   gemm_ukernels: microkernels indexed by (rows - 1).
//   datatype_init_flags: XNNPACK init flags that must all be set for this type.
//   operator_type: stored on the operator and used in log messages.
//   caches: optional; when non-NULL its weights cache is attached.
//   fully_connected_op_out: receives the new operator on success only.
//
// Returns xnn_status_success, or the status of the first failing check. On
// failure any partially constructed operator is deleted and
// *fully_connected_op_out is left untouched.
static enum xnn_status create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const void* kernel,
    const void* bias,
    uint32_t flags,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    xnn_pack_gemm_io_w_function pack_gemm_io_w,
    xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
    const void* packing_params,
    int packed_weights_padding_byte,
    const void* params,
    size_t params_size,
    const struct gemm_parameters* gemm_parameters,
    const struct gemm_fused_ukernels* gemm_ukernels,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    xnn_caches_t caches,
    xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  // `status` is advanced before each group of checks so that a `goto error`
  // reports the category of the check that failed.
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_unsupported_hardware;

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error(
      "failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  if (caches != NULL) {
    fully_connected_op->weights_cache = caches->weights_cache;
  }

  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  // Pad N up to a whole number of nr-wide tiles and K up to a multiple of
  // kr * sr, so every packed output channel has the same size.
  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr * sr);

  // Each of the n_stride packed output channels accounts for one bias element
  // plus k_stride filter elements.
  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
  const size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);
  void* weights_ptr = xnn_get_pointer_to_write_weights(
    fully_connected_op, aligned_total_weights_size, packed_weights_padding_byte);
  if (weights_ptr == NULL) {
    // Report the size that was actually requested, not the unaligned size.
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    pack_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      weights_ptr,
      packing_params);
  } else {
    pack_gemm_goi_w(
      1 /* groups */, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      weights_ptr,
      0 /* extra bytes */,
      packing_params);
  }

  if (use_weights_cache(fully_connected_op)) {
    fully_connected_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
      fully_connected_op->weights_cache, weights_ptr, aligned_total_weights_size);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  memcpy(&fully_connected_op->params, params, params_size);
  fully_connected_op->type = operator_type;
  fully_connected_op->flags = flags;

  const size_t mr = gemm_parameters->mr;
  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .mr = mr,
    .nr = nr,
    .kr = kr,
    .sr = sr,
  };

  // gemm_cases[i] is the microkernel for tiles of (i + 1) rows. Slot 0 takes
  // the dedicated 1-row kernel as-is; every larger slot reuses the full-mr
  // kernel, which can process any row count up to mr.
  assert(XNN_MAX_MR >= mr);
  fully_connected_op->ukernel.gemm.gemm_cases[0] = gemm_ukernels->gemm[0];
  for (size_t i = 1; i < mr; i++) {
    fully_connected_op->ukernel.gemm.gemm_cases[i] = gemm_ukernels->gemm[mr - 1];
  }

  // The operator must go through setup before it can run.
  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}
187*4bdc9457SAndroid Build Coastguard Worker
// Shared setup for the fully-connected (NC layout) operator variants.
//
// Checks that `fully_connected_op` was created as `expected_operator_type`,
// binds `input` / `output` for `batch_size` rows, and fills in the GEMM compute
// context plus a 2D tiled parallelization over (rows x output channels).
//
// Parameters:
//   log2_input_element_size: scales input channel counts/strides to bytes.
//   log2_filter_element_size, bias_element_size: must match the values used
//     when the weights were packed, because together they define the byte
//     stride between consecutive packed output channels.
//   log2_output_element_size: scales output strides to bytes.
//   params, params_size: datatype-specific parameters copied into the context.
//   num_threads: only influences the output-channel tile size `nc`
//     (outside XNN_TEST_MODE).
//
// Returns xnn_status_success (a batch_size of 0 marks the operator as skipped),
// xnn_status_invalid_parameter on an operator type mismatch,
// xnn_status_uninitialized if XNNPACK is not initialized, or
// xnn_status_invalid_state if an attached weights cache is not finalized.
static enum xnn_status setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    enum xnn_operator_type expected_operator_type,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t datatype_init_flags,
    uint32_t log2_input_element_size,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    uint32_t log2_output_element_size,
    const void* params,
    size_t params_size,
    size_t num_threads)
{
  if (fully_connected_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }
  // Any failure below leaves the operator in the invalid (not runnable) state.
  fully_connected_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  if (fully_connected_op->weights_cache != NULL &&
      !xnn_weights_cache_is_finalized(fully_connected_op->weights_cache)) {
    xnn_log_error("failed to setup %s operator: weights cache is not finalized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_state;
  }

  // The batch is modelled as a single image of batch_size x 1 pixels.
  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

  // Use the full-mr microkernel by default; a single-row batch switches to the
  // dedicated 1-row microkernel when one was provided at creation time.
  struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.gemm_cases[mr - 1];
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.gemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.gemm_cases[0];
    mr = 1;
  }

  // K padded to a multiple of kr * sr, exactly as done when packing.
  const size_t padded_input_channels = round_up_po2(
    input_channels, fully_connected_op->ukernel.gemm.kr * fully_connected_op->ukernel.gemm.sr);

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    // Packed weights hold filter elements, so their stride must be scaled by
    // the filter element size (not the input element size).
    .w_stride = bias_element_size + (padded_input_channels << log2_filter_element_size),
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = packed_weights(fully_connected_op),
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, params_size);
  fully_connected_op->context.gemm.fused_params = &fully_connected_op->context.gemm.params;

#if XNN_TEST_MODE
  const size_t nc = nr;
#else
  // Start with one column tile covering all output channels; with several
  // threads, shrink it (in multiples of nr) to aim for about
  // target_tiles_per_thread tiles per thread.
  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
#endif
#if XNN_MAX_UARCH_TYPES > 1
  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
    fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
    fully_connected_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
  } else {
    fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
    fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  }
#else
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
#endif
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
299*4bdc9457SAndroid Build Coastguard Worker
// Creates an F16 fully-connected operator (NC layout) whose output is clamped
// to [output_min, output_max].
//
// output_min / output_max are given in FP32 and converted to FP16. Both must
// be non-NaN, and after rounding to FP16 the lower bound must be strictly below
// the upper bound; otherwise xnn_status_invalid_parameter is returned.
//
// When `flags` contains XNN_FLAG_FP32_STATIC_WEIGHTS, the
// xnn_pack_f32_to_f16_gemm_* packing routines are used (so `kernel` / `bias`
// are presumably supplied as FP32); otherwise the FP16 packing routines are
// used. The packed filter and bias elements are 16-bit in both cases.
//
// All remaining validation and construction is delegated to
// create_fully_connected_nc.
enum xnn_status xnn_create_fully_connected_nc_f16(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const void* kernel,
    const void* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* fully_connected_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  // Validate the bounds as they will actually be applied, i.e. after rounding
  // to FP16: two distinct FP32 values can round to the same FP16 value.
  const uint16_t fp16_output_min = fp16_ieee_from_fp32_value(output_min);
  const uint16_t fp16_output_max = fp16_ieee_from_fp32_value(output_max);
  const float rounded_output_min = fp16_ieee_to_fp32_value(fp16_output_min);
  const float rounded_output_max = fp16_ieee_to_fp32_value(fp16_output_max);
  if (rounded_output_min >= rounded_output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16), rounded_output_min, rounded_output_max);
    return xnn_status_invalid_parameter;
  }

  union xnn_f16_minmax_params params;
  if XNN_LIKELY(xnn_params.f16.gemm.init.f16 != NULL) {
    xnn_params.f16.gemm.init.f16(&params, fp16_output_min, fp16_output_max);
  }

  xnn_pack_gemm_io_w_function pack_gemm_io_w = (xnn_pack_gemm_io_w_function) xnn_pack_f16_gemm_io_w;
  xnn_pack_gemm_goi_w_function pack_gemm_goi_w = (xnn_pack_gemm_goi_w_function) xnn_pack_f16_gemm_goi_w;
  if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
    pack_gemm_io_w = (xnn_pack_gemm_io_w_function) xnn_pack_f32_to_f16_gemm_io_w;
    pack_gemm_goi_w = (xnn_pack_gemm_goi_w_function) xnn_pack_f32_to_f16_gemm_goi_w;
  }

  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_gemm_io_w,
    pack_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f16.gemm, &xnn_params.f16.gemm.minmax,
    XNN_INIT_FLAG_F16,
    xnn_operator_type_fully_connected_nc_f16,
    caches,
    fully_connected_op_out);
}
364*4bdc9457SAndroid Build Coastguard Worker
// Creates a single-precision (float) fully-connected operator whose outputs are clamped to
// [output_min, output_max].
//
// Parameters:
//   input_channels, output_channels - dimensions of the weight matrix.
//   input_stride, output_stride     - strides of the input/output rows (NOTE(review): units and
//                                     validation live in create_fully_connected_nc — confirm).
//   kernel, bias                    - float weights and bias; they are handed to
//                                     create_fully_connected_nc together with the
//                                     xnn_pack_f32_gemm_io_w / xnn_pack_f32_gemm_goi_w packers.
//   output_min, output_max          - clamping bounds; both must be non-NaN and
//                                     output_min must be strictly below output_max.
//   flags, caches                   - forwarded unchanged to create_fully_connected_nc.
//   fully_connected_op_out          - forwarded to create_fully_connected_nc to receive the operator.
//
// Returns xnn_status_invalid_parameter for a NaN bound or an empty/inverted range; otherwise
// the status returned by create_fully_connected_nc.
enum xnn_status xnn_create_fully_connected_nc_f32(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* fully_connected_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // When the range is exactly (-inf, +inf) clamping is a no-op, so prefer the "linear"
  // micro-kernels if one exists for the full mr-row tile. mr is checked first: with mr == 0
  // (presumably XNNPACK not initialized / F32 GEMM unavailable) gemm[mr - 1] would index out
  // of bounds. In that case the min/max kernels are kept and the failure is left for
  // create_fully_connected_nc to report (NOTE(review): confirm it checks the init flags).
  const struct gemm_fused_ukernels* gemm_ukernels = &xnn_params.f32.gemm.minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && xnn_params.f32.gemm.mr != 0 &&
      xnn_params.f32.gemm.linear.gemm[xnn_params.f32.gemm.mr - 1].function[XNN_UARCH_DEFAULT] != NULL)
  {
    gemm_ukernels = &xnn_params.f32.gemm.linear;
  }

  // Clamping parameters are filled in only when an init routine is registered.
  union xnn_f32_minmax_params params;
  if XNN_LIKELY(xnn_params.f32.gemm.init.f32 != NULL) {
    xnn_params.f32.gemm.init.f32(&params, output_min, output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f32.gemm, gemm_ukernels,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_fully_connected_nc_f32,
    caches,
    fully_connected_op_out);
}
425*4bdc9457SAndroid Build Coastguard Worker
// Creates a signed 8-bit quantized (QS8) fully-connected operator.
//
// Parameters:
//   input_channels, output_channels - dimensions of the weight matrix.
//   input_stride, output_stride     - strides of the input/output rows, forwarded unchanged.
//   input_zero_point, input_scale   - input quantization; the zero point is passed to the weight
//                                     packer via xnn_qs8_packing_params.
//   kernel_scale, kernel            - int8_t weights and their scale (no kernel zero point here).
//   bias                            - int32_t bias values.
//   output_zero_point, output_scale - output quantization.
//   output_min, output_max          - quantized clamping bounds; output_min must be < output_max.
//   flags, caches, fully_connected_op_out - forwarded unchanged to create_fully_connected_nc.
//
// Returns xnn_status_invalid_parameter when any scale is not a positive normal float or the
// clamping range is empty; xnn_status_unsupported_parameter when the combined requantization
// scale (input_scale * kernel_scale / output_scale) is >= 256; otherwise the status returned
// by create_fully_connected_nc.
enum xnn_status xnn_create_fully_connected_nc_qs8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* fully_connected_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Factor that maps int32 accumulators back to the output scale; values >= 256 are rejected
  // as unsupported rather than invalid.
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  // Requantization/clamping parameters are filled in only when an init routine is registered.
  union xnn_qs8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(&params, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qs8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
  };
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qs8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w,
    &packing_params, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_fully_connected_nc_qs8,
    caches,
    fully_connected_op_out);
}
505*4bdc9457SAndroid Build Coastguard Worker
// Creates an unsigned 8-bit quantized (QU8) fully-connected operator.
//
// Parameters:
//   input_channels, output_channels - dimensions of the weight matrix.
//   input_stride, output_stride     - strides of the input/output rows, forwarded unchanged.
//   input_zero_point, input_scale   - input quantization.
//   kernel_zero_point, kernel_scale - weight quantization; the zero point goes to the parameter
//                                     init routine, the packer, and the padding byte.
//   kernel, bias                    - uint8_t weights and int32_t bias values.
//   output_zero_point, output_scale - output quantization.
//   output_min, output_max          - quantized clamping bounds; output_min must be < output_max.
//   flags, caches, fully_connected_op_out - forwarded unchanged to create_fully_connected_nc.
//
// Returns xnn_status_invalid_parameter when any scale is not a positive normal float or the
// clamping range is empty; xnn_status_unsupported_parameter when the combined requantization
// scale (input_scale * kernel_scale / output_scale) is >= 256; otherwise the status returned
// by create_fully_connected_nc.
enum xnn_status xnn_create_fully_connected_nc_qu8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* fully_connected_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Factor that maps int32 accumulators back to the output scale; values >= 256 are rejected
  // as unsupported rather than invalid.
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  // Requantization/clamping parameters are filled in only when an init routine is registered.
  union xnn_qu8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  // The padding byte is the kernel zero point, presumably so that padded weight slots
  // dequantize to zero and do not perturb the accumulators (NOTE(review): confirm in packer).
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qu8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    &packing_params, kernel_zero_point /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_fully_connected_nc_qu8,
    caches,
    fully_connected_op_out);
}
588*4bdc9457SAndroid Build Coastguard Worker
// Binds a batch of half-precision input/output buffers to an F16 fully-connected operator.
//
// Thin wrapper over setup_fully_connected_nc. Inputs, filters, biases, and outputs are all
// 2-byte (uint16_t-sized) elements. The operator's f16_minmax parameters and the thread count
// of `threadpool` are forwarded unchanged.
//
// The datatype init flag must describe the data this operator processes (F16). Every sibling
// setup in this file passes the flag of its own datatype; passing the F32 flag here validated
// the wrong capability.
enum xnn_status xnn_setup_fully_connected_nc_f16(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_fully_connected_nc(
    fully_connected_op, xnn_operator_type_fully_connected_nc_f16,
    batch_size,
    input, output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    1 /* log2(sizeof(output element)) = log2(sizeof(uint16_t)) */,
    &fully_connected_op->params.f16_minmax,
    sizeof(fully_connected_op->params.f16_minmax),
    pthreadpool_get_threads_count(threadpool));
}
609*4bdc9457SAndroid Build Coastguard Worker
// Attaches float input/output buffers for `batch_size` rows to an F32 fully-connected operator.
//
// Delegates to setup_fully_connected_nc. Every element kind (input, filter, output) is a
// 4-byte float, and the bias element size is sizeof(float). The operator's stored f32_minmax
// parameters are passed through, along with the number of threads reported by `threadpool`.
enum xnn_status xnn_setup_fully_connected_nc_f32(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  // log2(sizeof(float)) == 2 for every float-typed buffer the kernel touches.
  const uint32_t log2_float_size = 2;
  const size_t thread_count = pthreadpool_get_threads_count(threadpool);

  return setup_fully_connected_nc(
    fully_connected_op, xnn_operator_type_fully_connected_nc_f32,
    batch_size, input, output,
    XNN_INIT_FLAG_F32,
    log2_float_size /* input element */,
    log2_float_size /* filter element */,
    sizeof(float) /* bias element */,
    log2_float_size /* output element */,
    &fully_connected_op->params.f32_minmax,
    sizeof(fully_connected_op->params.f32_minmax),
    thread_count);
}
630*4bdc9457SAndroid Build Coastguard Worker
// Attaches signed 8-bit quantized input/output buffers for `batch_size` rows to a QS8
// fully-connected operator.
//
// Delegates to setup_fully_connected_nc. Input, filter, and output elements are single bytes
// and biases are int32_t. The operator's stored qs8_conv_minmax parameters are passed through,
// along with the number of threads reported by `threadpool`.
enum xnn_status xnn_setup_fully_connected_nc_qs8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // int8_t is one byte wide, so log2 of its size is zero.
  const uint32_t log2_int8_size = 0;
  const size_t thread_count = pthreadpool_get_threads_count(threadpool);

  return setup_fully_connected_nc(
    fully_connected_op, xnn_operator_type_fully_connected_nc_qs8,
    batch_size, input, output,
    XNN_INIT_FLAG_QS8,
    log2_int8_size /* input element */,
    log2_int8_size /* filter element */,
    sizeof(int32_t) /* bias element */,
    log2_int8_size /* output element */,
    &fully_connected_op->params.qs8_conv_minmax,
    sizeof(fully_connected_op->params.qs8_conv_minmax),
    thread_count);
}
651*4bdc9457SAndroid Build Coastguard Worker
// Attaches unsigned 8-bit quantized input/output buffers for `batch_size` rows to a QU8
// fully-connected operator.
//
// Delegates to setup_fully_connected_nc. Input, filter, and output elements are single bytes
// and biases are int32_t. The operator's stored qu8_conv_minmax parameters are passed through,
// along with the number of threads reported by `threadpool`.
enum xnn_status xnn_setup_fully_connected_nc_qu8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // uint8_t is one byte wide, so log2 of its size is zero.
  const uint32_t log2_uint8_size = 0;
  const size_t thread_count = pthreadpool_get_threads_count(threadpool);

  return setup_fully_connected_nc(
    fully_connected_op, xnn_operator_type_fully_connected_nc_qu8,
    batch_size, input, output,
    XNN_INIT_FLAG_QU8,
    log2_uint8_size /* input element */,
    log2_uint8_size /* filter element */,
    sizeof(int32_t) /* bias element */,
    log2_uint8_size /* output element */,
    &fully_connected_op->params.qu8_conv_minmax,
    sizeof(fully_connected_op->params.qu8_conv_minmax),
    thread_count);
}
672