1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include <xnnpack.h>
13 #include <xnnpack/log.h>
14 #include <xnnpack/operator.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/requantization.h>
17 #include <xnnpack/subgraph.h>
18 #include <xnnpack/subgraph-validation.h>
19
20
create_multiply_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)21 static enum xnn_status create_multiply_operator(
22 const struct xnn_node* node,
23 const struct xnn_value* values,
24 size_t num_values,
25 struct xnn_operator_data* opdata,
26 const struct xnn_caches* caches)
27 {
28 assert(node->num_inputs == 2);
29 const uint32_t input1_id = node->inputs[0];
30 assert(input1_id != XNN_INVALID_VALUE_ID);
31 assert(input1_id < num_values);
32 const uint32_t input2_id = node->inputs[1];
33 assert(input2_id != XNN_INVALID_VALUE_ID);
34 assert(input2_id < num_values);
35
36 assert(node->num_outputs == 1);
37 const uint32_t output_id = node->outputs[0];
38 assert(output_id != XNN_INVALID_VALUE_ID);
39 assert(output_id < num_values);
40
41 enum xnn_status status;
42 switch (node->compute_type) {
43 #ifndef XNN_NO_F16_OPERATORS
44 case xnn_compute_type_fp16:
45 status = xnn_create_multiply_nd_f16(
46 node->activation.output_min,
47 node->activation.output_max,
48 node->flags,
49 &opdata->operator_objects[0]);
50 break;
51 #endif // XNN_NO_F16_OPERATORS
52 case xnn_compute_type_fp32:
53 status = xnn_create_multiply_nd_f32(
54 node->activation.output_min,
55 node->activation.output_max,
56 node->flags,
57 &opdata->operator_objects[0]);
58 break;
59 #ifndef XNN_NO_QS8_OPERATORS
60 case xnn_compute_type_qs8:
61 {
62 const float output_scale = values[output_id].quantization.scale;
63 const int32_t output_zero_point = values[output_id].quantization.zero_point;
64 const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
65 const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
66 status = xnn_create_multiply_nd_qs8(
67 (int8_t) values[input1_id].quantization.zero_point,
68 values[input1_id].quantization.scale,
69 (int8_t) values[input2_id].quantization.zero_point,
70 values[input2_id].quantization.scale,
71 (int8_t) output_zero_point,
72 output_scale, output_min, output_max, node->flags,
73 &opdata->operator_objects[0]);
74 break;
75 }
76 #endif // !defined(XNN_NO_QS8_OPERATORS)
77 #ifndef XNN_NO_QU8_OPERATORS
78 case xnn_compute_type_qu8:
79 {
80 const float output_scale = values[output_id].quantization.scale;
81 const int32_t output_zero_point = values[output_id].quantization.zero_point;
82 const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
83 const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
84 status = xnn_create_multiply_nd_qu8(
85 (uint8_t) values[input1_id].quantization.zero_point,
86 values[input1_id].quantization.scale,
87 (uint8_t) values[input2_id].quantization.zero_point,
88 values[input2_id].quantization.scale,
89 (uint8_t) output_zero_point,
90 output_scale, output_min, output_max, node->flags,
91 &opdata->operator_objects[0]);
92 break;
93 }
94 #endif // !defined(XNN_NO_QU8_OPERATORS)
95 default:
96 XNN_UNREACHABLE;
97 }
98 if (status == xnn_status_success) {
99 opdata->shape1.num_dims = values[input1_id].shape.num_dims;
100 opdata->shape2.num_dims = values[input2_id].shape.num_dims;
101 if (values[output_id].layout == xnn_layout_type_nchw) {
102 assert(values[input1_id].layout == xnn_layout_type_nchw);
103 assert(values[input2_id].layout == xnn_layout_type_nchw);
104 opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
105 opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
106 if (values[input1_id].shape.num_dims > 2) {
107 memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
108 }
109 opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
110 opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
111 if (values[input1_id].shape.num_dims > 2) {
112 memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
113 }
114 } else {
115 assert(values[output_id].layout == xnn_layout_type_nhwc);
116 assert(values[input1_id].layout == xnn_layout_type_nhwc);
117 assert(values[input2_id].layout == xnn_layout_type_nhwc);
118 memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
119 memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
120 }
121 opdata->inputs[0] = input1_id;
122 opdata->inputs[1] = input2_id;
123 opdata->outputs[0] = output_id;
124 }
125 return status;
126 }
127
// Binds the runtime buffers of a Multiply2 node to its operator.
//
// Looks up the data pointers of the two inputs and the output in blobs (using
// the IDs stored in opdata) and forwards them, together with the two input
// shapes stored in opdata, to the xnn_setup_multiply_nd_* function matching
// the type of opdata->operator_objects[0].
//
// Parameters:
//   opdata     - operator data holding the operator, input shapes and IDs.
//   blobs      - blob array indexed by value ID; the three blobs used here
//                must have non-NULL data (asserted).
//   num_blobs  - number of entries in blobs; only used by assertions.
//   threadpool - passed through to the setup function.
//
// Returns the status reported by the xnn_setup_multiply_nd_* call.
static enum xnn_status setup_multiply_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  const uint32_t input2_id = opdata->inputs[1];
  const uint32_t output_id = opdata->outputs[0];

  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input1_data = blobs[input1_id].data;
  assert(input1_data != NULL);
  const void* input2_data = blobs[input2_id].data;
  assert(input2_data != NULL);
  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  // Every operator variant takes the same shape arguments.
  const size_t num_dims1 = opdata->shape1.num_dims;
  const size_t* dims1 = opdata->shape1.dim;
  const size_t num_dims2 = opdata->shape2.num_dims;
  const size_t* dims2 = opdata->shape2.dim;

  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_multiply_nd_f16:
      return xnn_setup_multiply_nd_f16(
        opdata->operator_objects[0],
        num_dims1, dims1,
        num_dims2, dims2,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_multiply_nd_f32:
      return xnn_setup_multiply_nd_f32(
        opdata->operator_objects[0],
        num_dims1, dims1,
        num_dims2, dims2,
        input1_data, input2_data, output_data,
        threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_multiply_nd_qs8:
      return xnn_setup_multiply_nd_qs8(
        opdata->operator_objects[0],
        num_dims1, dims1,
        num_dims2, dims2,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_multiply_nd_qu8:
      return xnn_setup_multiply_nd_qu8(
        opdata->operator_objects[0],
        num_dims1, dims1,
        num_dims2, dims2,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
209
// Defines a 2-input Multiply node and adds it to a subgraph.
//
// Validation runs in a fixed order and stops at the first failure:
//   1. xnn_subgraph_check_xnnpack_initialized
//   2. xnn_subgraph_check_output_min_max on output_min / output_max
//   3. first input: ID check, dense-type check, supported datatype
//   4. second input: ID check, dense-type check, supported datatype
//   5. output: ID check, dense-type check, supported datatype
//   6. xnn_subgraph_check_datatype_matches_two_inputs
// A non-success status from any xnn_subgraph_check_* helper is returned
// unchanged. An unsupported datatype is logged and reported as
// xnn_status_invalid_parameter.
//
// Supported datatypes are fp32, plus qint8 / quint8 unless compiled out with
// XNN_NO_QS8_OPERATORS / XNN_NO_QU8_OPERATORS. The node's compute type is
// derived from the output value's datatype.
//
// Parameters:
//   subgraph   - subgraph that owns the values and receives the new node.
//   output_min - lower activation bound stored on the node.
//   output_max - upper activation bound stored on the node.
//   input1_id  - value ID of the first input.
//   input2_id  - value ID of the second input.
//   output_id  - value ID of the output.
//   flags      - stored on the node unchanged.
//
// Returns xnn_status_success once the node is added, or
// xnn_status_out_of_memory if xnn_subgraph_new_node returns NULL.
enum xnn_status xnn_define_multiply2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_multiply2);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_multiply2, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  // First input.
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_multiply2, input1_id, subgraph->num_values, 1);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* first_value = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_multiply2, input1_id, first_value, 1);
  if (status != xnn_status_success) {
    return status;
  }

  switch (first_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input1_id,
        xnn_datatype_to_string(first_value->datatype), first_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Second input.
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_multiply2, input2_id, subgraph->num_values, 2);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* second_value = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_multiply2, input2_id, second_value, 2);
  if (status != xnn_status_success) {
    return status;
  }

  switch (second_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input2_id,
        xnn_datatype_to_string(second_value->datatype), second_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Output.
  status = xnn_subgraph_check_output_node_id(xnn_node_type_multiply2, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* result_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_multiply2, output_id, result_value);
  if (status != xnn_status_success) {
    return status;
  }

  // The output datatype decides which compute type the node runs with.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (result_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), output_id,
        xnn_datatype_to_string(result_value->datatype), result_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches_two_inputs(
    xnn_node_type_multiply2, input1_id, first_value, input2_id, second_value, output_id, result_value);
  if (status != xnn_status_success) {
    return status;
  }

  struct xnn_node* new_node = xnn_subgraph_new_node(subgraph);
  if (new_node == NULL) {
    return xnn_status_out_of_memory;
  }

  new_node->type = xnn_node_type_multiply2;
  new_node->compute_type = compute_type;
  new_node->flags = flags;
  new_node->activation.output_min = output_min;
  new_node->activation.output_max = output_max;

  new_node->num_inputs = 2;
  new_node->inputs[0] = input1_id;
  new_node->inputs[1] = input2_id;
  new_node->num_outputs = 1;
  new_node->outputs[0] = output_id;

  // Callbacks used later to create and set up the operator for this node.
  new_node->create = create_multiply_operator;
  new_node->setup = setup_multiply_operator;

  return xnn_status_success;
}
346