// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>

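// Creates the XNNPACK operator object backing a global average pooling node.
// An FP32 input in NCHW layout maps to the NCW operator; NHWC inputs map to
// the NWC operator in its FP32, FP16, QS8, or QU8 variant, with the quantized
// variants converting the float output clamp into quantized min/max bounds.
// On success, the node's batch size and pooled width are derived from the
// input shape (the last dimension is channels; 1D pooling reduces one spatial
// dimension, 2D pooling reduces two).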
static enum xnn_status create_global_average_pooling_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const size_t num_input_dims = values[input_id].shape.num_dims;
  assert(num_input_dims >= 1);
  const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status;
  if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
    assert(node->compute_type == xnn_compute_type_fp32);
    status = xnn_create_global_average_pooling_ncw_f32(
      channel_dim /* channels */,
      node->activation.output_min,
      node->activation.output_max,
      node->flags,
      &opdata->operator_objects[0]);
  } else {
    assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
    assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
    switch (node->compute_type) {
      case xnn_compute_type_fp32:
        status = xnn_create_global_average_pooling_nwc_f32(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#ifndef XNN_NO_F16_OPERATORS
      case xnn_compute_type_fp16:
        status = xnn_create_global_average_pooling_nwc_f16(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          &opdata->operator_objects[0]);
        break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
      case xnn_compute_type_qs8:
      {
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
        const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
        status = xnn_create_global_average_pooling_nwc_qs8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          (int8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
          (int8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
          output_min,
          output_max,
          node->flags,
          &opdata->operator_objects[0]);
        break;
      }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
      case xnn_compute_type_qu8:
      {
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
        const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
        status = xnn_create_global_average_pooling_nwc_qu8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          (uint8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
          (uint8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
          output_min,
          output_max,
          node->flags,
          &opdata->operator_objects[0]);
        break;
      }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      default:
        XNN_UNREACHABLE;
    }
  }
  if (status == xnn_status_success) {
    switch (node->type) {
      case xnn_node_type_global_average_pooling_1d:
        opdata->batch_size = xnn_shape_multiply_batch_dims(&values[input_id].shape, 2);
        opdata->input_width = values[input_id].shape.dim[num_input_dims - 2];
        break;
      case xnn_node_type_global_average_pooling_2d:
        opdata->batch_size = xnn_shape_multiply_batch_dims(&values[input_id].shape, 3);
        opdata->input_width = values[input_id].shape.dim[num_input_dims - 3] * values[input_id].shape.dim[num_input_dims - 2];
        break;
      default:
        XNN_UNREACHABLE;
    }
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

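// Binds the runtime input/output blobs to the operator created above and
// dispatches to the setup routine that matches the operator type (NCW FP32,
// or NWC FP32/FP16/QS8/QU8).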
static enum xnn_status setup_global_average_pooling_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_global_average_pooling_ncw_f32:
      return xnn_setup_global_average_pooling_ncw_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_global_average_pooling_nwc_f32:
      return xnn_setup_global_average_pooling_nwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_f16:
      return xnn_setup_global_average_pooling_nwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qs8:
      return xnn_setup_global_average_pooling_nwc_qs8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qu8:
      return xnn_setup_global_average_pooling_nwc_qu8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

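// Shared definition path for the 1D and 2D node types. Validates that XNNPACK
// is initialized, that the output clamp range is valid, and that the input and
// output value IDs refer to dense tensors of a supported and matching datatype
// (FP32, plus QS8/QU8 when those operator sets are compiled in), then records
// the node in the subgraph with the create/setup callbacks above.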
static enum xnn_status define_global_average_pooling_nd(
  xnn_subgraph_t subgraph,
  enum xnn_node_type node_type,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(node_type, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_input_node_id(node_type, input_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(node_type, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(node_type, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(node_type, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(node_type), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(
    node_type, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = node_type;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_global_average_pooling_operator;
  node->setup = setup_global_average_pooling_operator;

  return xnn_status_success;
}

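// Public API: defines a 1D global average pooling node that averages over the
// width dimension of a (batch, width, channels) input.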
enum xnn_status xnn_define_global_average_pooling_1d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  return define_global_average_pooling_nd(
    subgraph, xnn_node_type_global_average_pooling_1d, output_min, output_max, input_id, output_id, flags);
}

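// Public API: defines a 2D global average pooling node that averages over the
// height and width dimensions of a (batch, height, width, channels) input.
// Illustrative call (a sketch only; input_id and output_id are assumed to come
// from earlier xnn_define_tensor_value calls when the subgraph is built):
//
//   xnn_define_global_average_pooling_2d(
//     subgraph, -INFINITY, +INFINITY, input_id, output_id, /*flags=*/0);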
enum xnn_status xnn_define_global_average_pooling_2d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  return define_global_average_pooling_nd(
    subgraph, xnn_node_type_global_average_pooling_2d, output_min, output_max, input_id, output_id, flags);
}