1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/requantization.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18
19
create_global_average_pooling_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_global_average_pooling_operator(
21 const struct xnn_node* node,
22 const struct xnn_value* values,
23 size_t num_values,
24 struct xnn_operator_data* opdata,
25 const struct xnn_caches* caches)
26 {
27 assert(node->num_inputs == 1);
28 const uint32_t input_id = node->inputs[0];
29 assert(input_id != XNN_INVALID_VALUE_ID);
30 assert(input_id < num_values);
31
32 assert(node->num_outputs == 1);
33 const uint32_t output_id = node->outputs[0];
34 assert(output_id != XNN_INVALID_VALUE_ID);
35 assert(output_id < num_values);
36
37 const size_t num_input_dims = values[input_id].shape.num_dims;
38 assert(num_input_dims >= 1);
39 const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1];
40
41 enum xnn_status status;
42 if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
43 assert(node->compute_type == xnn_compute_type_fp32);
44 status = xnn_create_global_average_pooling_ncw_f32(
45 channel_dim /* channels */,
46 node->activation.output_min,
47 node->activation.output_max,
48 node->flags,
49 &opdata->operator_objects[0]);
50 } else {
51 assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
52 assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
53 switch (node->compute_type) {
54 case xnn_compute_type_fp32:
55 status = xnn_create_global_average_pooling_nwc_f32(
56 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
57 node->activation.output_min,
58 node->activation.output_max,
59 node->flags,
60 &opdata->operator_objects[0]);
61 break;
62 #ifndef XNN_NO_F16_OPERATORS
63 case xnn_compute_type_fp16:
64 status = xnn_create_global_average_pooling_nwc_f16(
65 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
66 node->activation.output_min,
67 node->activation.output_max,
68 node->flags,
69 &opdata->operator_objects[0]);
70 break;
71 #endif // !defined(XNN_NO_F16_OPERATORS)
72 #ifndef XNN_NO_QS8_OPERATORS
73 case xnn_compute_type_qs8:
74 {
75 const float output_scale = values[output_id].quantization.scale;
76 const int32_t output_zero_point = values[output_id].quantization.zero_point;
77 const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
78 const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
79 status = xnn_create_global_average_pooling_nwc_qs8(
80 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
81 (int8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
82 (int8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
83 output_min,
84 output_max,
85 node->flags,
86 &opdata->operator_objects[0]);
87 break;
88 }
89 #endif // !defined(XNN_NO_QS8_OPERATORS)
90 #ifndef XNN_NO_QU8_OPERATORS
91 case xnn_compute_type_qu8:
92 {
93 const float output_scale = values[output_id].quantization.scale;
94 const int32_t output_zero_point = values[output_id].quantization.zero_point;
95 const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
96 const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
97 status = xnn_create_global_average_pooling_nwc_qu8(
98 channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
99 (uint8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
100 (uint8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
101 output_min,
102 output_max,
103 node->flags,
104 &opdata->operator_objects[0]);
105 break;
106 }
107 #endif // !defined(XNN_NO_QU8_OPERATORS)
108 default:
109 XNN_UNREACHABLE;
110 }
111 }
112 if (status == xnn_status_success) {
113 switch (node->type) {
114 case xnn_node_type_global_average_pooling_1d:
115 opdata->batch_size = xnn_shape_multiply_batch_dims(&values[input_id].shape, 2);
116 opdata->input_width = values[input_id].shape.dim[num_input_dims - 2];
117 break;
118 case xnn_node_type_global_average_pooling_2d:
119 opdata->batch_size = xnn_shape_multiply_batch_dims(&values[input_id].shape, 3);
120 opdata->input_width = values[input_id].shape.dim[num_input_dims - 3] * values[input_id].shape.dim[num_input_dims - 2];
121 break;
122 default:
123 XNN_UNREACHABLE;
124 }
125 opdata->inputs[0] = input_id;
126 opdata->outputs[0] = output_id;
127 }
128 return status;
129 }
130
// Sets up the Global Average Pooling operator held in opdata for execution.
//
// Looks up the input and output buffers in blobs using the Value IDs recorded
// in opdata, then forwards them, together with the stored batch size and
// input width, to the xnn_setup_global_average_pooling_* function that matches
// the type of the operator object.
//
// Returns the status reported by that setup function.
static enum xnn_status setup_global_average_pooling_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  const uint32_t output_id = opdata->outputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data = blobs[input_id].data;
  void* output_data = blobs[output_id].data;
  assert(input_data != NULL);
  assert(output_data != NULL);

  // Every supported case returns directly; only an unexpected operator type
  // reaches the default label.
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_global_average_pooling_ncw_f32:
      return xnn_setup_global_average_pooling_ncw_f32(
        opdata->operator_objects[0],
        opdata->batch_size, opdata->input_width,
        input_data, output_data,
        threadpool);
    case xnn_operator_type_global_average_pooling_nwc_f32:
      return xnn_setup_global_average_pooling_nwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size, opdata->input_width,
        input_data, output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_f16:
      return xnn_setup_global_average_pooling_nwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qs8:
      return xnn_setup_global_average_pooling_nwc_qs8(
        opdata->operator_objects[0],
        opdata->batch_size, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qu8:
      return xnn_setup_global_average_pooling_nwc_qu8(
        opdata->operator_objects[0],
        opdata->batch_size, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
209
// Validates the arguments of a 1D or 2D Global Average Pooling definition and
// appends the corresponding node to the subgraph.
//
// Parameters:
//   subgraph    - subgraph that owns the Values and receives the new node.
//   node_type   - xnn_node_type_global_average_pooling_1d or
//                 xnn_node_type_global_average_pooling_2d.
//   output_min,
//   output_max  - activation bounds stored on the node.
//   input_id,
//   output_id   - IDs of the input and output Values in the subgraph.
//   flags       - stored on the node unchanged.
//
// The input and output Values must be dense and of datatype fp32, qint8 or
// quint8 (the quantized types only when compiled in). The input must have at
// least 2 dimensions for 1D pooling and at least 3 for 2D pooling, because
// operator creation reads that many trailing dimensions of its shape.
//
// Returns xnn_status_success once the node is added,
// xnn_status_invalid_parameter for an unsupported datatype or an input with
// too few dimensions, xnn_status_out_of_memory if no node could be allocated,
// or the failure status of one of the xnn_subgraph_check_* helpers.
static enum xnn_status define_global_average_pooling_nd(
  xnn_subgraph_t subgraph,
  enum xnn_node_type node_type,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(node_type);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(node_type, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_input_node_id(node_type, input_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Operator creation indexes the input shape at num_dims - 2 (1D pooling) and
  // additionally at num_dims - 3 (2D pooling). Reject inputs with fewer
  // dimensions here, so those unsigned index computations can never wrap.
  size_t min_input_dims = 0;
  switch (node_type) {
    case xnn_node_type_global_average_pooling_1d:
      min_input_dims = 2;
      break;
    case xnn_node_type_global_average_pooling_2d:
      min_input_dims = 3;
      break;
    default:
      XNN_UNREACHABLE;
  }
  const size_t num_input_dims = input_value->shape.num_dims;
  if (num_input_dims < min_input_dims) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": input has %zu dimensions, at least %zu required",
      xnn_node_type_to_string(node_type), input_id, num_input_dims, min_input_dims);
    return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // The compute type is derived from the output datatype.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(
    node_type, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = node_type;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_global_average_pooling_operator;
  node->setup = setup_global_average_pooling_operator;

  return xnn_status_success;
}
317
// Defines a 1D Global Average Pooling node in the subgraph.
//
// All arguments are forwarded unchanged to define_global_average_pooling_nd()
// with the node type fixed to xnn_node_type_global_average_pooling_1d; the
// returned status is the one produced by that function.
enum xnn_status xnn_define_global_average_pooling_1d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  const enum xnn_node_type node_type = xnn_node_type_global_average_pooling_1d;
  return define_global_average_pooling_nd(
    subgraph, node_type, output_min, output_max, input_id, output_id, flags);
}
329
// Defines a 2D Global Average Pooling node in the subgraph.
//
// All arguments are forwarded unchanged to define_global_average_pooling_nd()
// with the node type fixed to xnn_node_type_global_average_pooling_2d; the
// returned status is the one produced by that function.
enum xnn_status xnn_define_global_average_pooling_2d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  const enum xnn_node_type node_type = xnn_node_type_global_average_pooling_2d;
  return define_global_average_pooling_nd(
    subgraph, node_type, output_min, output_max, input_id, output_id, flags);
}
341