// xref: /aosp_15_r20/external/XNNPACK/src/subgraph/clamp.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

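// This file implements the Clamp node of the XNNPACK subgraph API: operator
// creation (create_clamp_operator), operator setup (setup_clamp_operator), and
// the public node-definition entry point (xnn_define_clamp).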
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>


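// Node "create" hook: instantiates the type-specific clamp operator for the
// node's compute type. For quantized (QS8/QU8) clamps, the floating-point
// activation bounds are quantized with the output tensor's scale and zero
// point before being passed to the operator.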
static enum xnn_status create_clamp_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const size_t num_input_dims = values[input_id].shape.num_dims;
  const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_clamp_nc_f16(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#endif  // XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp32:
      status = xnn_create_clamp_nc_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_compute_type_qs8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_clamp_nc_s8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_clamp_nc_u8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  if (status == xnn_status_success) {
    opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

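// Node "setup" hook: dispatches on the type of the operator created above and
// binds the runtime input/output blob pointers and the recorded batch size to
// that operator.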
static enum xnn_status setup_clamp_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_clamp_nc_f16:
      return xnn_setup_clamp_nc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_clamp_nc_f32:
      return xnn_setup_clamp_nc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_clamp_nc_s8:
      return xnn_setup_clamp_nc_s8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_clamp_nc_u8:
      return xnn_setup_clamp_nc_u8(
        opdata->operator_objects[0],
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

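// Public subgraph API: validates the input and output values, derives the
// compute type from the output datatype, checks that quantization parameters
// match across input and output, and appends a Clamp node (with the create
// and setup hooks above) to the subgraph.
//
// A minimal usage sketch (assuming input_id and output_id are dense FP32
// values already defined on the subgraph, e.g. for a ReLU6-style activation):
//
//   enum xnn_status status = xnn_define_clamp(
//     subgraph, /*output_min=*/0.0f, /*output_max=*/6.0f,
//     input_id, output_id, /*flags=*/0);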
enum xnn_status xnn_define_clamp(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_clamp)) != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_clamp, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_clamp, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_clamp, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_clamp, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }
  assert(compute_type != xnn_compute_type_invalid);

  status = xnn_subgraph_check_datatype_matches(xnn_node_type_clamp, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_clamp;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_clamp_operator;
  node->setup = setup_clamp_operator;

  return xnn_status_success;
}