// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>  // PRIu32, used in the log messages below
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>


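// Subgraph implementation of the Add2 node: create_add_operator() and
// setup_add_operator() are the per-node callbacks that instantiate and
// configure the N-dimensional Add operator for the node's compute type,
// and xnn_define_add2() is the public entry point that validates the
// inputs/output and records the node in the subgraph.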
static enum xnn_status create_add_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 2);
  const uint32_t input1_id = node->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_values);
  const uint32_t input2_id = node->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  enum xnn_status status;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_add_nd_f16(
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32:
      status = xnn_create_add_nd_f32(
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
    {
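      // Quantize the float activation bounds into the output's quantized
      // domain using the output scale and zero point.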
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_add_nd_qs8(
        (int8_t) values[input1_id].quantization.zero_point,
        values[input1_id].quantization.scale,
        (int8_t) values[input2_id].quantization.zero_point,
        values[input2_id].quantization.scale,
        (int8_t) output_zero_point,
        output_scale, output_min, output_max, node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_add_nd_qu8(
        (uint8_t) values[input1_id].quantization.zero_point,
        values[input1_id].quantization.scale,
        (uint8_t) values[input2_id].quantization.zero_point,
        values[input2_id].quantization.scale,
        (uint8_t) output_zero_point,
        output_scale, output_min, output_max, node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  if (status == xnn_status_success) {
    opdata->shape1.num_dims = values[input1_id].shape.num_dims;
    opdata->shape2.num_dims = values[input2_id].shape.num_dims;
    if (values[output_id].layout == xnn_layout_type_nchw) {
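      // Value shapes are recorded with channels in the last dimension; for the
      // NCHW layout the broadcast shapes passed to the operator are reordered
      // to (N, C, spatial...) so they match the actual memory layout.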
      assert(values[input1_id].layout == xnn_layout_type_nchw);
      assert(values[input2_id].layout == xnn_layout_type_nchw);
      opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
      opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
      if (values[input1_id].shape.num_dims > 2) {
        memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
      }
      opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
      opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
      if (values[input2_id].shape.num_dims > 2) {
        memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
      }
    } else {
      assert(values[output_id].layout == xnn_layout_type_nhwc);
      assert(values[input1_id].layout == xnn_layout_type_nhwc);
      assert(values[input2_id].layout == xnn_layout_type_nhwc);
      memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
      memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
    }
    opdata->inputs[0] = input1_id;
    opdata->inputs[1] = input2_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

static enum xnn_status setup_add_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);

  const uint32_t input2_id = opdata->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input1_blob = blobs + input1_id;
  const void* input1_data = input1_blob->data;
  assert(input1_data != NULL);

  const struct xnn_blob* input2_blob = blobs + input2_id;
  const void* input2_data = input2_blob->data;
  assert(input2_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

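  // Dispatch on the operator type instantiated in create_add_operator().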
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_add_nd_f32:
      return xnn_setup_add_nd_f32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_add_nd_f16:
      return xnn_setup_add_nd_f16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_add_nd_qs8:
      return xnn_setup_add_nd_qs8(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_add_nd_qu8:
      return xnn_setup_add_nd_qu8(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

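// A minimal usage sketch for this entry point. Illustrative only: the
// subgraph handle, external IDs, and shape below are hypothetical and are
// not part of this file.
//
//   uint32_t input1_id, input2_id, output_id;
//   const size_t dims[4] = {1, 8, 8, 16};
//   xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, dims, NULL,
//     /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input1_id);
//   xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, dims, NULL,
//     /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input2_id);
//   xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 4, dims, NULL,
//     /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
//   xnn_define_add2(subgraph, -INFINITY, INFINITY,
//     input1_id, input2_id, output_id, /*flags=*/0);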
enum xnn_status xnn_define_add2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_add2)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_add2, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_add2, input1_id, subgraph->num_values, 1)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input1_value = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_add2, input1_id, input1_value, 1);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input1_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), input1_id,
        xnn_datatype_to_string(input1_value->datatype), input1_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if ((status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_add2, input2_id, subgraph->num_values, 2)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input2_value = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_add2, input2_id, input2_value, 2);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input2_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), input2_id,
        xnn_datatype_to_string(input2_value->datatype), input2_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_add2, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_add2, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

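  // The compute type follows the output datatype. The fp16 compute type is
  // not selected here; it is expected to be set later by the subgraph's fp16
  // rewrite (when enabled), which is why create_add_operator() also handles
  // xnn_compute_type_fp16.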
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches_two_inputs(
      xnn_node_type_add2, input1_id, input1_value, input2_id, input2_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_add2;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2;
  node->inputs[0] = input1_id;
  node->inputs[1] = input2_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_add_operator;
  node->setup = setup_add_operator;

  return xnn_status_success;
}
342