xref: /aosp_15_r20/external/XNNPACK/src/subgraph/convert.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17 
18 
19 
create_convert_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_convert_operator(
21   const struct xnn_node* node,
22   const struct xnn_value* values,
23   size_t num_values,
24   struct xnn_operator_data* opdata,
25   const struct xnn_caches* caches)
26 {
27   assert(node->num_inputs == 1);
28   const uint32_t input_id = node->inputs[0];
29   assert(input_id != XNN_INVALID_VALUE_ID);
30   assert(input_id < num_values);
31 
32   assert(node->num_outputs == 1);
33   const uint32_t output_id = node->outputs[0];
34   assert(output_id != XNN_INVALID_VALUE_ID);
35   assert(output_id < num_values);
36 
37   const size_t num_input_dims = values[input_id].shape.num_dims;
38   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
39 
40   enum xnn_status status = xnn_status_uninitialized;
41   switch (node->compute_type) {
42     case xnn_compute_type_fp32_to_fp16:
43       status = xnn_create_convert_nc_f32_f16(
44         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
45         node->flags,
46         &opdata->operator_objects[0]);
47       break;
48     case xnn_compute_type_fp32_to_qs8:
49       status = xnn_create_convert_nc_f32_qs8(
50         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
51         values[output_id].quantization.scale,
52         (int8_t) values[output_id].quantization.zero_point,
53         INT8_MIN, INT8_MAX,
54         node->flags,
55         &opdata->operator_objects[0]);
56       break;
57     case xnn_compute_type_fp32_to_qu8:
58       status = xnn_create_convert_nc_f32_qu8(
59         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
60         values[output_id].quantization.scale,
61         (uint8_t) values[output_id].quantization.zero_point,
62         0, UINT8_MAX,
63         node->flags,
64         &opdata->operator_objects[0]);
65       break;
66     case xnn_compute_type_fp16_to_fp32:
67       status = xnn_create_convert_nc_f16_f32(
68         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
69         node->flags,
70         &opdata->operator_objects[0]);
71       break;
72     case xnn_compute_type_qs8:
73       status = xnn_create_convert_nc_qs8(
74         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
75         values[input_id].quantization.scale,
76         (int8_t) values[input_id].quantization.zero_point,
77         values[output_id].quantization.scale,
78         (int8_t) values[output_id].quantization.zero_point,
79         node->flags,
80         &opdata->operator_objects[0]);
81       break;
82     case xnn_compute_type_qs8_to_fp32:
83       status = xnn_create_convert_nc_qs8_f32(
84         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
85         values[input_id].quantization.scale,
86         (int8_t) values[input_id].quantization.zero_point,
87         node->flags,
88         &opdata->operator_objects[0]);
89       break;
90     case xnn_compute_type_qu8:
91       status = xnn_create_convert_nc_qu8(
92         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
93         values[input_id].quantization.scale,
94         (uint8_t) values[input_id].quantization.zero_point,
95         values[output_id].quantization.scale,
96         (uint8_t) values[output_id].quantization.zero_point,
97         node->flags,
98         &opdata->operator_objects[0]);
99       break;
100     case xnn_compute_type_qu8_to_fp32:
101       status = xnn_create_convert_nc_qu8_f32(
102         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
103         values[input_id].quantization.scale,
104         (uint8_t) values[input_id].quantization.zero_point,
105         node->flags,
106         &opdata->operator_objects[0]);
107       break;
108     default:
109       XNN_UNREACHABLE;
110   }
111   if (status == xnn_status_success) {
112     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
113     opdata->inputs[0] = input_id;
114     opdata->outputs[0] = output_id;
115   }
116   return status;
117 }
118 
// Binds input and output buffers to a previously created Convert operator.
//
// The input/output Value IDs stored in `opdata` index into `blobs`; both data
// pointers are asserted to be non-NULL. The xnn_setup_convert_nc_* routine is
// selected from the type of opdata->operator_objects[0] and is invoked with
// opdata->batch_size, the two data pointers, and `threadpool`. Its status is
// returned directly.
static enum xnn_status setup_convert_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t src_id = opdata->inputs[0];
  assert(src_id != XNN_INVALID_VALUE_ID);
  assert(src_id < num_blobs);

  const uint32_t dst_id = opdata->outputs[0];
  assert(dst_id != XNN_INVALID_VALUE_ID);
  assert(dst_id < num_blobs);

  const void* src = blobs[src_id].data;
  assert(src != NULL);

  void* dst = blobs[dst_id].data;
  assert(dst != NULL);

  // Every branch forwards the same arguments; only the setup routine differs.
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_convert_nc_f32_f16:
      return xnn_setup_convert_nc_f32_f16(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_f32_qs8:
      return xnn_setup_convert_nc_f32_qs8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_f32_qu8:
      return xnn_setup_convert_nc_f32_qu8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_f16_f32:
      return xnn_setup_convert_nc_f16_f32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_qs8:
      return xnn_setup_convert_nc_qs8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_qs8_f32:
      return xnn_setup_convert_nc_qs8_f32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_qu8:
      return xnn_setup_convert_nc_qu8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    case xnn_operator_type_convert_nc_qu8_f32:
      return xnn_setup_convert_nc_qu8_f32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    default:
      XNN_UNREACHABLE;
  }
}
202 
validate_datatypes(enum xnn_datatype input_datatype,enum xnn_datatype output_datatype)203 static inline enum xnn_compute_type validate_datatypes(
204   enum xnn_datatype input_datatype,
205   enum xnn_datatype output_datatype)
206 {
207   switch (input_datatype) {
208     case xnn_datatype_fp32:
209       switch (output_datatype) {
210         case xnn_datatype_fp16:
211           return xnn_compute_type_fp32_to_fp16;
212         case xnn_datatype_qint8:
213           return xnn_compute_type_fp32_to_qs8;
214         case xnn_datatype_quint8:
215           return xnn_compute_type_fp32_to_qu8;
216         default:
217           break;
218       }
219       break;
220     case xnn_datatype_fp16:
221       if (output_datatype == xnn_datatype_fp32) {
222         return xnn_compute_type_fp16_to_fp32;
223       }
224       break;
225     case xnn_datatype_qint8:
226       switch (output_datatype) {
227         case xnn_datatype_fp32:
228           return xnn_compute_type_qs8_to_fp32;
229         case xnn_datatype_qint8:
230           return xnn_compute_type_qs8;
231         default:
232           break;
233       }
234       break;
235     case xnn_datatype_quint8:
236       switch (output_datatype) {
237         case xnn_datatype_fp32:
238           return xnn_compute_type_qu8_to_fp32;
239         case xnn_datatype_quint8:
240           return xnn_compute_type_qu8;
241         default:
242           break;
243       }
244       break;
245     default:
246       XNN_UNREACHABLE;
247   }
248   return xnn_compute_type_invalid;
249 }
250 
xnn_init_convert_node(struct xnn_node * node,enum xnn_compute_type compute_type,uint32_t input_id,uint32_t output_id,uint32_t flags)251 void xnn_init_convert_node(
252   struct xnn_node* node,
253   enum xnn_compute_type compute_type,
254   uint32_t input_id,
255   uint32_t output_id,
256   uint32_t flags)
257 {
258   node->type = xnn_node_type_convert;
259   node->compute_type = compute_type;
260   node->num_inputs = 1;
261   node->inputs[0] = input_id;
262   node->num_outputs = 1;
263   node->outputs[0] = output_id;
264   node->flags = flags;
265 
266   node->create = create_convert_operator;
267   node->setup = setup_convert_operator;
268 }
269 
// Adds a Convert node to `subgraph` that converts Value `input_id` into Value
// `output_id`.
//
// Validation performed, in order:
//   1. xnn_subgraph_check_xnnpack_initialized
//   2. input ID check and dense-type check, then input datatype must be one of
//      fp16, fp32, qint8, quint8
//   3. output ID check and dense-type check, then output datatype must be one
//      of fp16, fp32, qint8, quint8
//   4. the (input, output) datatype pair must map to a compute type via
//      validate_datatypes
//   5. for qint8->qint8 and quint8->quint8, the input-to-output scale ratio
//      must lie in [2**-8, 2**7]
//
// Returns the failing status of a helper check, xnn_status_invalid_parameter
// for an unsupported datatype / datatype pair / scale ratio,
// xnn_status_out_of_memory if a new node cannot be obtained from the subgraph,
// and xnn_status_success once the node has been initialized. `flags` is stored
// on the node without inspection.
enum xnn_status xnn_define_convert(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_convert)) != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_convert, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_convert, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_convert, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_convert, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Both datatypes are individually supported; now check that the pair is.
  const enum xnn_compute_type compute_type = validate_datatypes(input_value->datatype, output_value->datatype);
  if (compute_type == xnn_compute_type_invalid) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  // Requantization (same quantized datatype on both sides) additionally
  // constrains the ratio between the input and output scales.
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    const float input_output_scale = input_value->quantization.scale / output_value->quantization.scale;
    // Written as a negated in-range test so that a NaN ratio (for which every
    // ordered comparison is false) is rejected rather than silently accepted.
    if (!(input_output_scale >= 0x1.0p-8f && input_output_scale <= 0x1.0p+7f)) {
      xnn_log_error(
        "failed to define %s operator with %.7g input-to-output scale ratio (input #%"PRIu32" scale %.7g, output #%"PRIu32" scale %.7g): "
        "scale ratio must be in [2**-8, 2**7] range",
        xnn_node_type_to_string(xnn_node_type_convert), input_output_scale,
        input_id, input_value->quantization.scale, output_id, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  xnn_init_convert_node(node, compute_type, input_id, output_id, flags);
  return xnn_status_success;
}
377