// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>

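// Creates the N-dimensional Add operator backing this node: the fp16, fp32, qs8, or
// qu8 variant is selected from the node's compute type, and the input value IDs and
// (layout-adjusted) input shapes are recorded in opdata for use at setup time.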
static enum xnn_status create_add_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 2);
  const uint32_t input1_id = node->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_values);
  const uint32_t input2_id = node->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  enum xnn_status status;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_add_nd_f16(
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#endif // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32:
      status = xnn_create_add_nd_f32(
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_compute_type_qs8:
    {
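      // The node's clamp limits are stored as fp32; convert them into the output
      // tensor's quantized representation before creating the QS8 operator.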
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_add_nd_qs8(
        (int8_t) values[input1_id].quantization.zero_point,
        values[input1_id].quantization.scale,
        (int8_t) values[input2_id].quantization.zero_point,
        values[input2_id].quantization.scale,
        (int8_t) output_zero_point,
        output_scale, output_min, output_max, node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_add_nd_qu8(
        (uint8_t) values[input1_id].quantization.zero_point,
        values[input1_id].quantization.scale,
        (uint8_t) values[input2_id].quantization.zero_point,
        values[input2_id].quantization.scale,
        (uint8_t) output_zero_point,
        output_scale, output_min, output_max, node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  if (status == xnn_status_success) {
    opdata->shape1.num_dims = values[input1_id].shape.num_dims;
    opdata->shape2.num_dims = values[input2_id].shape.num_dims;
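    // Value shapes are expressed in channels-last order; when the tensors use NCHW
    // layout, move the trailing channel dimension to position 1 so the shapes passed
    // to the ND Add operator match the actual memory layout.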
    if (values[output_id].layout == xnn_layout_type_nchw) {
      assert(values[input1_id].layout == xnn_layout_type_nchw);
      assert(values[input2_id].layout == xnn_layout_type_nchw);
      opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
      opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
      if (values[input1_id].shape.num_dims > 2) {
        memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
      }
      opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
      opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
      if (values[input2_id].shape.num_dims > 2) {
        memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
      }
    } else {
      assert(values[output_id].layout == xnn_layout_type_nhwc);
      assert(values[input1_id].layout == xnn_layout_type_nhwc);
      assert(values[input2_id].layout == xnn_layout_type_nhwc);
      memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
      memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
    }
    opdata->inputs[0] = input1_id;
    opdata->inputs[1] = input2_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

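// Binds the node's input and output blobs to the previously created Add operator and
// configures its broadcast shapes, dispatching on the type of the operator object.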
static enum xnn_status setup_add_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);

  const uint32_t input2_id = opdata->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input1_blob = blobs + input1_id;
  const void* input1_data = input1_blob->data;
  assert(input1_data != NULL);

  const struct xnn_blob* input2_blob = blobs + input2_id;
  const void* input2_data = input2_blob->data;
  assert(input2_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_add_nd_f32:
      return xnn_setup_add_nd_f32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_add_nd_f16:
      return xnn_setup_add_nd_f16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_add_nd_qs8:
      return xnn_setup_add_nd_qs8(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_add_nd_qu8:
      return xnn_setup_add_nd_qu8(
        opdata->operator_objects[0],
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->shape2.num_dims,
        opdata->shape2.dim,
        input1_data, input2_data, output_data,
        threadpool);
#endif // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

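// Defines a 2-input Add node in the subgraph. Both inputs and the output must be
// dense Values of matching datatype (fp32, or qint8/quint8 when the quantized
// operators are compiled in); output_min/output_max clamp the result.
//
// Illustrative usage (a sketch, not part of this file; a_id, b_id, and sum_id are
// hypothetical value IDs assumed to have been defined earlier, e.g. with
// xnn_define_tensor_value):
//
//   enum xnn_status status = xnn_define_add2(
//     subgraph, /*output_min=*/-INFINITY, /*output_max=*/INFINITY,
//     a_id, b_id, sum_id, /*flags=*/0);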
enum xnn_status xnn_define_add2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_add2)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_add2, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_add2, input1_id, subgraph->num_values, 1)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input1_value = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_add2, input1_id, input1_value, 1);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input1_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), input1_id,
        xnn_datatype_to_string(input1_value->datatype), input1_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if ((status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_add2, input2_id, subgraph->num_values, 2)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input2_value = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_add2, input2_id, input2_value, 2);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input2_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), input2_id,
        xnn_datatype_to_string(input2_value->datatype), input2_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_add2, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_add2, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_add2), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches_two_inputs(
    xnn_node_type_add2, input1_id, input1_value, input2_id, input2_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_add2;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2;
  node->inputs[0] = input1_id;
  node->inputs[1] = input2_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_add_operator;
  node->setup = setup_add_operator;

  return xnn_status_success;
}