// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>

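// Creates the underlying convert operator for a CONVERT subgraph node, selecting the concrete
// operator variant (e.g. F32->F16, F32->QS8, QS8->F32) from the node's compute type.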
static enum xnn_status create_convert_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

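  // CONVERT maps onto the NC convert operators: the innermost input dimension is used as the
  // channel count (a 0-dimensional input counts as a single channel), and the remaining
  // dimensions are folded into the batch size once the operator is created.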
  const size_t num_input_dims = values[input_id].shape.num_dims;
  const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status = xnn_status_uninitialized;
  switch (node->compute_type) {
    case xnn_compute_type_fp32_to_fp16:
      status = xnn_create_convert_nc_f32_f16(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_fp32_to_qs8:
      status = xnn_create_convert_nc_f32_qs8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[output_id].quantization.scale,
        (int8_t) values[output_id].quantization.zero_point,
        INT8_MIN, INT8_MAX,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_fp32_to_qu8:
      status = xnn_create_convert_nc_f32_qu8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[output_id].quantization.scale,
        (uint8_t) values[output_id].quantization.zero_point,
        0, UINT8_MAX,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_fp16_to_fp32:
      status = xnn_create_convert_nc_f16_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_qs8:
      status = xnn_create_convert_nc_qs8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[input_id].quantization.scale,
        (int8_t) values[input_id].quantization.zero_point,
        values[output_id].quantization.scale,
        (int8_t) values[output_id].quantization.zero_point,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_qs8_to_fp32:
      status = xnn_create_convert_nc_qs8_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[input_id].quantization.scale,
        (int8_t) values[input_id].quantization.zero_point,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_qu8:
      status = xnn_create_convert_nc_qu8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[input_id].quantization.scale,
        (uint8_t) values[input_id].quantization.zero_point,
        values[output_id].quantization.scale,
        (uint8_t) values[output_id].quantization.zero_point,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    case xnn_compute_type_qu8_to_fp32:
      status = xnn_create_convert_nc_qu8_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        values[input_id].quantization.scale,
        (uint8_t) values[input_id].quantization.zero_point,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    default:
      XNN_UNREACHABLE;
  }
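  // On success, record the batch size (the product of all non-channel dimensions) and the
  // input/output value IDs that setup_convert_operator will bind to blobs.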
  if (status == xnn_status_success) {
    opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

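// Binds the node's input and output blobs to the convert operator created above and sets it up
// for the recorded batch size.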
static enum xnn_status setup_convert_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

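  // Dispatch on the operator type chosen in create_convert_operator.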
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_convert_nc_f32_f16:
      return xnn_setup_convert_nc_f32_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_f32_qs8:
      return xnn_setup_convert_nc_f32_qs8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_f32_qu8:
      return xnn_setup_convert_nc_f32_qu8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_f16_f32:
      return xnn_setup_convert_nc_f16_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_qs8:
      return xnn_setup_convert_nc_qs8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_qs8_f32:
      return xnn_setup_convert_nc_qs8_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_qu8:
      return xnn_setup_convert_nc_qu8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convert_nc_qu8_f32:
      return xnn_setup_convert_nc_qu8_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
    default:
      XNN_UNREACHABLE;
  }
}

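// Maps an (input, output) datatype pair to the compute type of the corresponding conversion,
// or returns xnn_compute_type_invalid if the combination is unsupported.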
static inline enum xnn_compute_type validate_datatypes(
  enum xnn_datatype input_datatype,
  enum xnn_datatype output_datatype)
{
  switch (input_datatype) {
    case xnn_datatype_fp32:
      switch (output_datatype) {
        case xnn_datatype_fp16:
          return xnn_compute_type_fp32_to_fp16;
        case xnn_datatype_qint8:
          return xnn_compute_type_fp32_to_qs8;
        case xnn_datatype_quint8:
          return xnn_compute_type_fp32_to_qu8;
        default:
          break;
      }
      break;
    case xnn_datatype_fp16:
      if (output_datatype == xnn_datatype_fp32) {
        return xnn_compute_type_fp16_to_fp32;
      }
      break;
    case xnn_datatype_qint8:
      switch (output_datatype) {
        case xnn_datatype_fp32:
          return xnn_compute_type_qs8_to_fp32;
        case xnn_datatype_qint8:
          return xnn_compute_type_qs8;
        default:
          break;
      }
      break;
    case xnn_datatype_quint8:
      switch (output_datatype) {
        case xnn_datatype_fp32:
          return xnn_compute_type_qu8_to_fp32;
        case xnn_datatype_quint8:
          return xnn_compute_type_qu8;
        default:
          break;
      }
      break;
    default:
      XNN_UNREACHABLE;
  }
  return xnn_compute_type_invalid;
}

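// Initializes a subgraph node as a single-input, single-output CONVERT node and wires up its
// create and setup callbacks.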
void xnn_init_convert_node(
  struct xnn_node* node,
  enum xnn_compute_type compute_type,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  node->type = xnn_node_type_convert;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_convert_operator;
  node->setup = setup_convert_operator;
}

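// Defines a CONVERT node that converts the input Value to the datatype and quantization
// parameters of the output Value, validating that the (input, output) datatype pair is a
// supported conversion and, for quantized-to-quantized conversions, that the scale ratio is
// within the supported range.
//
// Minimal usage sketch (hypothetical tensor IDs, shape, and flags; assumes the public subgraph
// API declared in <xnnpack.h>; error handling omitted):
//
//   xnn_initialize(/*allocator=*/NULL);
//
//   xnn_subgraph_t subgraph = NULL;
//   xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
//
//   const size_t dims[2] = {1, 16};
//   uint32_t fp32_id = XNN_INVALID_VALUE_ID;
//   xnn_define_tensor_value(
//     subgraph, xnn_datatype_fp32, /*num_dims=*/2, dims, /*data=*/NULL,
//     /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &fp32_id);
//
//   uint32_t fp16_id = XNN_INVALID_VALUE_ID;
//   xnn_define_tensor_value(
//     subgraph, xnn_datatype_fp16, /*num_dims=*/2, dims, /*data=*/NULL,
//     /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &fp16_id);
//
//   xnn_define_convert(subgraph, fp32_id, fp16_id, /*flags=*/0);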
enum xnn_status xnn_define_convert(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_convert)) != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_convert, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_convert, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

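  // Only FP16, FP32, and 8-bit quantized (QS8/QU8) values may participate in a CONVERT node.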
  switch (input_value->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_convert, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_convert, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp16:
    case xnn_datatype_fp32:
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convert), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = validate_datatypes(input_value->datatype, output_value->datatype);
  if (compute_type == xnn_compute_type_invalid) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_convert), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  switch (compute_type) {
    case xnn_compute_type_qs8:
    case xnn_compute_type_qu8:
    {
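      // The quantized-to-quantized convert kernels only support input-to-output scale ratios
      // in the [2**-8, 2**7] range.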
      const float input_output_scale = input_value->quantization.scale / output_value->quantization.scale;
      if (input_output_scale < 0x1.0p-8f || input_output_scale > 0x1.0p+7f) {
        xnn_log_error(
          "failed to define %s operator with %.7g input-to-output scale ratio (input #%"PRIu32" scale %.7g, output #%"PRIu32" scale %.7g): "
          "scale ratio must be in [2**-8, 2**7] range",
          xnn_node_type_to_string(xnn_node_type_convert), input_output_scale,
          input_id, input_value->quantization.scale, output_id, output_value->quantization.scale);
        return xnn_status_invalid_parameter;
      }
      break;
    }
    default:
      break;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  xnn_init_convert_node(node, compute_type, input_id, output_id, flags);
  return xnn_status_success;
}