/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace elementwise {
namespace {

const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";

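// Per-node quantization state computed in Prepare and consumed in Eval.
// multiplier/shift hold the fixed-point rescaling factor produced by
// QuantizeMultiplier; input_offset/output_offset are the tensor zero points;
// needs_rescale records whether the input and output scales differ.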
struct OpData {
  int32_t multiplier;
  int32_t shift;
  int input_offset;
  int output_offset;
  bool needs_rescale;
};

bool IsNumericSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32;
}

bool IsLogicalSupportedType(const TfLiteType type) {
  return type == kTfLiteBool;
}

bool IsAbsSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
}

bool IsRsqrtSupportedType(const TfLiteType type) {
  return type == kTfLiteFloat32 || type == kTfLiteInt8;
}

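// Abs keeps each quantized value in the input's units and only has to bridge
// the two affine scales, so the real-valued factor is input_scale /
// output_scale, encoded as a fixed-point multiplier plus a power-of-two shift.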
inline void SetAbsOutputMultiplier(const float input_scale,
                                   const float output_scale,
                                   int32_t* multiplier, int32_t* shift) {
  QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
}

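// For Rsqrt, the real output is 1 / sqrt(input_scale * (q_in - zero_point)),
// so the scale-dependent part that can be folded into a fixed multiplier is
// 1 / (sqrt(input_scale) * output_scale); the data-dependent 1 / sqrt(q_in -
// zero_point) factor is computed per element in RsqrtEvalQuantized below.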
inline void SetRsqrtOutputMultiplier(const float input_scale,
                                     const float output_scale,
                                     int32_t* multiplier, int32_t* shift) {
  const double scale = 1. / (std::sqrt(input_scale) * output_scale);
  QuantizeMultiplier(scale, multiplier, shift);
}

typedef bool (*IsSupportedType)(TfLiteType);
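// Shared Prepare step: checks the single input/output pair, verifies the type
// against the op-specific predicate, and, for quantized int8 and quantized
// int16 tensors, caches zero points and the output multiplier/shift in OpData
// before resizing the output to the input shape.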
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node,
                            IsSupportedType is_supported_type,
                            const char* op_name) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  if (!is_supported_type(input->type)) {
    TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
  }
  // For int16 type input, we support both quantized and non-quantized
  // evaluation.
  if (input->type == kTfLiteInt8 ||
      (input->type == kTfLiteInt16 &&
       input->quantization.type != kTfLiteNoQuantization)) {
    TfLiteTensor* output = GetOutput(context, node, 0);
    auto* op_data = static_cast<OpData*>(node->user_data);
    TF_LITE_ENSURE_EQ(context, input->quantization.type,
                      kTfLiteAffineQuantization);
    TF_LITE_ENSURE_EQ(context, output->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    TF_LITE_ENSURE(context, input_params != nullptr);
    TF_LITE_ENSURE(context, input_params->scale != nullptr);
    TF_LITE_ENSURE(context, input_params->scale->size > 0);
    TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
    TF_LITE_ENSURE(context, output_params != nullptr);
    TF_LITE_ENSURE(context, output_params->scale != nullptr);
    TF_LITE_ENSURE(context, output_params->scale->size > 0);
    TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
    op_data->input_offset = input_params->zero_point->data[0];
    op_data->output_offset = output_params->zero_point->data[0];
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
    }
    const float input_scale = input_params->scale->data[0];
    const float output_scale = output_params->scale->data[0];
    op_data->needs_rescale = input_scale != output_scale;
    if (op_name == kAbsName && op_data->needs_rescale) {
      SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                             &op_data->shift);
    } else if (op_name == kRsqrtName) {
      SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
                               &op_data->shift);
    }
  }
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

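// Generic element-wise loop: optionally validates each input element, then
// applies `func` to produce the corresponding output element. Both tensors
// are assumed to already have the same shape (Prepare resizes the output).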
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             std::function<TfLiteStatus(T)> validate_input_func,
                             TfLiteType expected_type) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
  const int64_t num_elements = NumElements(input);
  const T* in_data = GetTensorData<T>(input);
  T* out_data = GetTensorData<T>(output);
  for (int64_t i = 0; i < num_elements; ++i) {
    if (validate_input_func) {
      TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
    }
    out_data[i] = func(in_data[i]);
  }
  return kTfLiteOk;
}

// Non-quantized evaluation of Abs op when input is int16.
inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node,
                                     TfLiteType expected_type) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
  const int64_t num_elements = NumElements(input);
  const int16_t* in_data = GetTensorData<int16_t>(input);
  int16_t* out_data = GetTensorData<int16_t>(output);
  for (int64_t i = 0; i < num_elements; ++i) {
    out_data[i] = static_cast<int16_t>(
        std::abs<int32_t>(static_cast<int32_t>(in_data[i])));
  }
  return kTfLiteOk;
}

template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
                             std::function<T(T)> func,
                             TfLiteType expected_type) {
  return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
                     expected_type);
}

inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
                                float float_func(float)) {
  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
}

inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
                                bool bool_func(bool)) {
  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
}

void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
                               size_t length) {
  return new OpData();
}

void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

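// Quantized Abs: |q_in - input_offset| stays in input units; if the scales
// differ, the cached multiplier/shift rescale it to output units before the
// output offset is added and the result is clamped to the type's range.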
template <typename T>
TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<T>::min();
  const int kMax = std::numeric_limits<T>::max();

  std::function<T(T)> func = [&](T i) {
    const int32_t value = std::abs(i - op_data->input_offset);
    if (!op_data->needs_rescale) {
      return static_cast<T>(
          std::min(std::max(value + op_data->output_offset, kMin), kMax));
    }
    const int32_t output = MultiplyByQuantizedMultiplier(
                               value, op_data->multiplier, op_data->shift) +
                           op_data->output_offset;
    return static_cast<T>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<T>(context, node, func, type);
}

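// Dispatch for Abs: float uses the plain element-wise loop, int8 always goes
// through the quantized path, and int16 picks the quantized or non-quantized
// path depending on whether quantization parameters are present.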
TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteType type = input->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(context, node, std::abs<float>, type);
    case kTfLiteInt8:
      return AbsEvalQuantized<int8_t>(context, node, type);
    case kTfLiteInt16:
      return input->quantization.type == kTfLiteNoQuantization
                 ? AbsInt16EvalImpl(context, node, type)
                 : AbsEvalQuantized<int16_t>(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}

TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sin);
}

TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::cos);
}

TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::log);
}

TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, std::sqrt);
}

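// Quantized Rsqrt (int8 only): each element first gets an integer 1/sqrt
// estimate via GetInvSqrtQuantizedMultiplierExp, carried with kShift = 20
// extra fractional bits, and is then rescaled by the multiplier/shift cached
// in Prepare (which folds in 1 / (sqrt(input_scale) * output_scale)) before
// the output offset is added and the value is clamped to the int8 range.
// Inputs below the zero point are rejected by validate_input_func.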
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}

TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteType type = GetInput(context, node, 0)->type;
  switch (type) {
    case kTfLiteFloat32:
      return EvalImpl<float>(
          context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
    case kTfLiteInt8:
      return RsqrtEvalQuantized(context, node, type);
    default:
      TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
                         TfLiteTypeGetName(type));
      return kTfLiteError;
  }
}

TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalNumeric(context, node, [](float f) { return f * f; });
}

TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
  return EvalLogical(context, node, [](bool v) { return !v; });
}

}  // namespace
}  // namespace elementwise

// Given a function...
// template<int T>
// int Foo(int b)
//
// typedef int(*Bar)(int);
//
// MSVC2015 will not see Foo<10> as the same type as Bar.
//
// This works around the issue by instantiating wrapper methods around
// elementwise::GenericPrepare() rather than using a templated
// elementwise::GenericPrepare method.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name)  \
  static TfLiteStatus function_name(TfLiteContext* context,                    \
                                    TfLiteNode* node) {                        \
    return elementwise::GenericPrepare(context, node,                          \
                                       is_supported_type_function, type_name); \
  }

GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
                elementwise::kAbsName)

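// Each registration supplies init, free, prepare, and invoke handlers. Abs
// and Rsqrt use the quantized init/free pair so that Prepare has an OpData to
// fill in; the float-only ops below register with null init/free.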
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareAbs, elementwise::AbsEval};
  return &r;
}

GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin")

TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin,
                                 elementwise::SinEval};
  return &r;
}

GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos")

TfLiteRegistration* Register_COS() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos,
                                 elementwise::CosEval};
  return &r;
}

GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log")

TfLiteRegistration* Register_LOG() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog,
                                 elementwise::LogEval};
  return &r;
}

GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")

TfLiteRegistration* Register_SQRT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSqrt, elementwise::SqrtEval};
  return &r;
}

GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType,
                elementwise::kRsqrtName)

TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareRsqrt, elementwise::RsqrtEval};
  return &r;
}

GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square")

TfLiteRegistration* Register_SQUARE() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSquare, elementwise::SquareEval};
  return &r;
}

GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not")

TfLiteRegistration* Register_LOGICAL_NOT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot,
                                 elementwise::LogicalNotEval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite