// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>

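// Creates the concrete clamp operator for a Clamp node: the f16/f32/s8/u8 variant is
// chosen from the node's compute type, and the batch size and tensor IDs are recorded
// in opdata for later setup.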
static enum xnn_status create_clamp_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

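  // The clamp operator uses an NC layout: the last input dimension is treated as the
  // channel dimension, and all other dimensions are folded into the batch size below.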
  const size_t num_input_dims = values[input_id].shape.num_dims;
  const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_clamp_nc_f16(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#endif  // XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp32:
      status = xnn_create_clamp_nc_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_compute_type_qs8:
    {
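      // Quantize the float clamp bounds with the output tensor's quantization parameters
      // so the clamp is applied directly in the quantized domain.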
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_clamp_nc_s8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_clamp_nc_u8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
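  // On success, record the batch size (the product of the non-channel dimensions) and
  // the input/output tensor IDs for use at setup time.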
  if (status == xnn_status_success) {
    opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

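// Binds the input and output blobs to the previously created clamp operator and sets it
// up to run for the recorded batch size.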
static enum xnn_status setup_clamp_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_clamp_nc_f16:
      return xnn_setup_clamp_nc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_clamp_nc_f32:
      return xnn_setup_clamp_nc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_clamp_nc_s8:
      return xnn_setup_clamp_nc_s8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_clamp_nc_u8:
      return xnn_setup_clamp_nc_u8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

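// Defines a Clamp node in the subgraph, validating the input/output tensor IDs, datatypes,
// and (for quantized datatypes) quantization parameters before the node is recorded.
//
// A minimal usage sketch (illustrative only; the tensor definitions are elided, and the
// ReLU6-style bounds below are an assumption, not taken from this file):
//
//   uint32_t input_id = XNN_INVALID_VALUE_ID, output_id = XNN_INVALID_VALUE_ID;
//   // ... define the input_id and output_id tensors on the subgraph ...
//   enum xnn_status status = xnn_define_clamp(
//     subgraph, /*output_min=*/0.0f, /*output_max=*/6.0f, input_id, output_id, /*flags=*/0);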
enum xnn_status xnn_define_clamp(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_clamp)) != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_clamp, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_clamp, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_clamp, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_clamp, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }
  assert(compute_type != xnn_compute_type_invalid);

  status = xnn_subgraph_check_datatype_matches(xnn_node_type_clamp, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

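  // Clamp does not requantize, so for quantized datatypes the input and output tensors
  // must share the same zero point and scale.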
#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_clamp;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_clamp_operator;
  node->setup = setup_clamp_operator;

  return xnn_status_success;
}