xref: /aosp_15_r20/external/XNNPACK/src/subgraph/multiply2.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11 
12 #include <xnnpack.h>
13 #include <xnnpack/log.h>
14 #include <xnnpack/operator.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/requantization.h>
17 #include <xnnpack/subgraph.h>
18 #include <xnnpack/subgraph-validation.h>
19 
20 
create_multiply_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)21 static enum xnn_status create_multiply_operator(
22   const struct xnn_node* node,
23   const struct xnn_value* values,
24   size_t num_values,
25   struct xnn_operator_data* opdata,
26   const struct xnn_caches* caches)
27 {
28   assert(node->num_inputs == 2);
29   const uint32_t input1_id = node->inputs[0];
30   assert(input1_id != XNN_INVALID_VALUE_ID);
31   assert(input1_id < num_values);
32   const uint32_t input2_id = node->inputs[1];
33   assert(input2_id != XNN_INVALID_VALUE_ID);
34   assert(input2_id < num_values);
35 
36   assert(node->num_outputs == 1);
37   const uint32_t output_id = node->outputs[0];
38   assert(output_id != XNN_INVALID_VALUE_ID);
39   assert(output_id < num_values);
40 
41   enum xnn_status status;
42   switch (node->compute_type) {
43 #ifndef XNN_NO_F16_OPERATORS
44     case xnn_compute_type_fp16:
45       status = xnn_create_multiply_nd_f16(
46         node->activation.output_min,
47         node->activation.output_max,
48         node->flags,
49         &opdata->operator_objects[0]);
50       break;
51 #endif  // XNN_NO_F16_OPERATORS
52     case xnn_compute_type_fp32:
53       status = xnn_create_multiply_nd_f32(
54         node->activation.output_min,
55         node->activation.output_max,
56         node->flags,
57         &opdata->operator_objects[0]);
58       break;
59 #ifndef XNN_NO_QS8_OPERATORS
60     case xnn_compute_type_qs8:
61     {
62       const float output_scale = values[output_id].quantization.scale;
63       const int32_t output_zero_point = values[output_id].quantization.zero_point;
64       const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
65       const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
66       status = xnn_create_multiply_nd_qs8(
67         (int8_t) values[input1_id].quantization.zero_point,
68         values[input1_id].quantization.scale,
69         (int8_t) values[input2_id].quantization.zero_point,
70         values[input2_id].quantization.scale,
71         (int8_t) output_zero_point,
72         output_scale, output_min, output_max, node->flags,
73         &opdata->operator_objects[0]);
74       break;
75     }
76 #endif  // !defined(XNN_NO_QS8_OPERATORS)
77 #ifndef XNN_NO_QU8_OPERATORS
78     case xnn_compute_type_qu8:
79     {
80       const float output_scale = values[output_id].quantization.scale;
81       const int32_t output_zero_point = values[output_id].quantization.zero_point;
82       const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
83       const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
84       status = xnn_create_multiply_nd_qu8(
85         (uint8_t) values[input1_id].quantization.zero_point,
86         values[input1_id].quantization.scale,
87         (uint8_t) values[input2_id].quantization.zero_point,
88         values[input2_id].quantization.scale,
89         (uint8_t) output_zero_point,
90         output_scale, output_min, output_max, node->flags,
91         &opdata->operator_objects[0]);
92       break;
93     }
94 #endif  // !defined(XNN_NO_QU8_OPERATORS)
95     default:
96       XNN_UNREACHABLE;
97   }
98   if (status == xnn_status_success) {
99     opdata->shape1.num_dims = values[input1_id].shape.num_dims;
100     opdata->shape2.num_dims = values[input2_id].shape.num_dims;
101     if (values[output_id].layout == xnn_layout_type_nchw) {
102       assert(values[input1_id].layout == xnn_layout_type_nchw);
103       assert(values[input2_id].layout == xnn_layout_type_nchw);
104       opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
105       opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
106       if (values[input1_id].shape.num_dims > 2) {
107         memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
108       }
109       opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
110       opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
111       if (values[input1_id].shape.num_dims > 2) {
112         memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
113       }
114     } else {
115       assert(values[output_id].layout == xnn_layout_type_nhwc);
116       assert(values[input1_id].layout == xnn_layout_type_nhwc);
117       assert(values[input2_id].layout == xnn_layout_type_nhwc);
118       memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
119       memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
120     }
121     opdata->inputs[0] = input1_id;
122     opdata->inputs[1] = input2_id;
123     opdata->outputs[0] = output_id;
124   }
125   return status;
126 }
127 
// Sets up the previously created multiplication operator of a Multiply2 node.
//
// Looks up the data pointers of both inputs and of the output in `blobs`
// (indexed by the value IDs recorded in opdata) and forwards them, together
// with the recorded input shapes, to the xnn_setup_multiply_nd_* function that
// matches the type of opdata->operator_objects[0].
//
// Returns the status reported by that setup function.
static enum xnn_status setup_multiply_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input1_id = opdata->inputs[0];
  assert(input1_id != XNN_INVALID_VALUE_ID);
  assert(input1_id < num_blobs);

  const uint32_t input2_id = opdata->inputs[1];
  assert(input2_id != XNN_INVALID_VALUE_ID);
  assert(input2_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input1_data = blobs[input1_id].data;
  assert(input1_data != NULL);

  const void* input2_data = blobs[input2_id].data;
  assert(input2_data != NULL);

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  // Every operator flavour receives the same shape arguments.
  const size_t num_input1_dims = opdata->shape1.num_dims;
  const size_t* input1_shape = opdata->shape1.dim;
  const size_t num_input2_dims = opdata->shape2.num_dims;
  const size_t* input2_shape = opdata->shape2.dim;

  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_multiply_nd_f16:
      return xnn_setup_multiply_nd_f16(
        opdata->operator_objects[0],
        num_input1_dims, input1_shape,
        num_input2_dims, input2_shape,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_multiply_nd_f32:
      return xnn_setup_multiply_nd_f32(
        opdata->operator_objects[0],
        num_input1_dims, input1_shape,
        num_input2_dims, input2_shape,
        input1_data, input2_data, output_data,
        threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_multiply_nd_qs8:
      return xnn_setup_multiply_nd_qs8(
        opdata->operator_objects[0],
        num_input1_dims, input1_shape,
        num_input2_dims, input2_shape,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_multiply_nd_qu8:
      return xnn_setup_multiply_nd_qu8(
        opdata->operator_objects[0],
        num_input1_dims, input1_shape,
        num_input2_dims, input2_shape,
        input1_data, input2_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
209 
// Adds a Multiply2 node (element-wise product of two inputs) to a subgraph.
//
// Validation is performed in this order: library initialization, the
// [output_min, output_max] range, the first input (ID, dense type, datatype),
// the second input (same checks), the output (ID, dense type, datatype), and
// finally datatype agreement between both inputs and the output. The first
// failing check determines the returned status. Only when all checks pass is a
// new node allocated and filled in.
//
// Returns xnn_status_success on success, xnn_status_invalid_parameter for an
// unsupported datatype, xnn_status_out_of_memory if no node could be
// allocated, or the status of whichever validation helper failed.
enum xnn_status xnn_define_multiply2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_multiply2);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_multiply2, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  // --- First input ---
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_multiply2, input1_id, subgraph->num_values, 1);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input1_value = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_multiply2, input1_id, input1_value, 1);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input1_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      // Supported datatype: nothing to do.
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input1_id,
        xnn_datatype_to_string(input1_value->datatype), input1_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // --- Second input ---
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_multiply2, input2_id, subgraph->num_values, 2);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input2_value = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_multiply2, input2_id, input2_value, 2);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input2_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      // Supported datatype: nothing to do.
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), input2_id,
        xnn_datatype_to_string(input2_value->datatype), input2_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // --- Output ---
  status = xnn_subgraph_check_output_node_id(xnn_node_type_multiply2, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_multiply2, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // The output datatype selects the compute type of the node.
  enum xnn_compute_type node_compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      node_compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      node_compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      node_compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_multiply2), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches_two_inputs(
    xnn_node_type_multiply2, input1_id, input1_value, input2_id, input2_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // All checks passed: append the node and fill it in.
  struct xnn_node* new_node = xnn_subgraph_new_node(subgraph);
  if (new_node == NULL) {
    return xnn_status_out_of_memory;
  }

  new_node->type = xnn_node_type_multiply2;
  new_node->compute_type = node_compute_type;
  new_node->activation.output_min = output_min;
  new_node->activation.output_max = output_max;
  new_node->flags = flags;

  new_node->num_inputs = 2;
  new_node->inputs[0] = input1_id;
  new_node->inputs[1] = input2_id;

  new_node->num_outputs = 1;
  new_node->outputs[0] = output_id;

  new_node->create = create_multiply_operator;
  new_node->setup = setup_multiply_operator;

  return xnn_status_success;
}
346