1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17 #include <stdlib.h>
18
19 #include <algorithm>
20 #include <cmath>
21 #include <functional>
22 #include <limits>
23
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/kernels/internal/quantization_util.h"
26 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
27 #include "tensorflow/lite/kernels/internal/tensor.h"
28 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29 #include "tensorflow/lite/kernels/kernel_util.h"
30 #include "tensorflow/lite/kernels/op_macros.h"
31
32 namespace tflite {
33 namespace ops {
34 namespace builtin {
35 namespace elementwise {
36 namespace {
37
// Op-name constants. NOTE: GenericPrepare dispatches on these by POINTER
// identity (op_name == kAbsName), not by string contents, so callers must
// pass these exact array objects for the Abs/Rsqrt quantized setup to run.
const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";
40
// Per-node state computed in Prepare and consumed by the quantized Eval
// paths. Allocated by ElementWiseQuantizedInit, freed by
// ElementWiseQuantizedFree.
struct OpData {
  int32_t multiplier;   // Fixed-point multiplier for output rescaling.
  int32_t shift;        // Exponent/shift paired with `multiplier`.
  int input_offset;     // Zero point of the quantized input tensor.
  int output_offset;    // Zero point of the quantized output tensor.
  bool needs_rescale;   // True when input and output scales differ.
};
48
IsNumericSupportedType(const TfLiteType type)49 bool IsNumericSupportedType(const TfLiteType type) {
50 return type == kTfLiteFloat32;
51 }
52
IsLogicalSupportedType(const TfLiteType type)53 bool IsLogicalSupportedType(const TfLiteType type) {
54 return type == kTfLiteBool;
55 }
56
IsAbsSupportedType(const TfLiteType type)57 bool IsAbsSupportedType(const TfLiteType type) {
58 return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
59 }
60
IsRsqrtSupportedType(const TfLiteType type)61 bool IsRsqrtSupportedType(const TfLiteType type) {
62 return type == kTfLiteFloat32 || type == kTfLiteInt8;
63 }
64
SetAbsOutputMultiplier(const float input_scale,const float output_scale,int32_t * multiplier,int32_t * shift)65 inline void SetAbsOutputMultiplier(const float input_scale,
66 const float output_scale,
67 int32_t* multiplier, int32_t* shift) {
68 QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
69 }
70
SetRsqrtOutputMultiplier(const float input_scale,const float output_scale,int32_t * multiplier,int32_t * shift)71 inline void SetRsqrtOutputMultiplier(const float input_scale,
72 const float output_scale,
73 int32_t* multiplier, int32_t* shift) {
74 const double scale = 1. / (std::sqrt(input_scale) * output_scale);
75 QuantizeMultiplier(scale, multiplier, shift);
76 }
77
78 typedef bool (*IsSupportedType)(TfLiteType);
GenericPrepare(TfLiteContext * context,TfLiteNode * node,IsSupportedType is_supported_type,const char * op_name)79 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node,
80 IsSupportedType is_supported_type,
81 const char* op_name) {
82 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
83 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
84 const TfLiteTensor* input;
85 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
86 TfLiteTensor* output;
87 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
88 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
89 if (!is_supported_type(input->type)) {
90 TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
91 }
92 // For int16 type input, we support both quantized and non-quantized
93 // evaluation.
94 if (input->type == kTfLiteInt8 ||
95 (input->type == kTfLiteInt16 &&
96 input->quantization.type != kTfLiteNoQuantization)) {
97 TfLiteTensor* output = GetOutput(context, node, 0);
98 auto* op_data = static_cast<OpData*>(node->user_data);
99 TF_LITE_ENSURE_EQ(context, input->quantization.type,
100 kTfLiteAffineQuantization);
101 TF_LITE_ENSURE_EQ(context, output->quantization.type,
102 kTfLiteAffineQuantization);
103 const auto* input_params =
104 reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
105 const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
106 output->quantization.params);
107 TF_LITE_ENSURE(context, input_params != nullptr);
108 TF_LITE_ENSURE(context, input_params->scale != nullptr);
109 TF_LITE_ENSURE(context, input_params->scale->size > 0);
110 TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
111 TF_LITE_ENSURE(context, output_params != nullptr);
112 TF_LITE_ENSURE(context, output_params->scale != nullptr);
113 TF_LITE_ENSURE(context, output_params->scale->size > 0);
114 TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
115 op_data->input_offset = input_params->zero_point->data[0];
116 op_data->output_offset = output_params->zero_point->data[0];
117 if (input->type == kTfLiteInt16) {
118 TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
119 TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
120 }
121 const float input_scale = input_params->scale->data[0];
122 const float output_scale = output_params->scale->data[0];
123 op_data->needs_rescale = input_scale != output_scale;
124 if (op_name == kAbsName && op_data->needs_rescale) {
125 SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
126 &op_data->shift);
127 } else if (op_name == kRsqrtName) {
128 SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
129 &op_data->shift);
130 }
131 }
132 return context->ResizeTensor(context, output,
133 TfLiteIntArrayCopy(input->dims));
134 }
135
136 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,std::function<T (T)> func,std::function<TfLiteStatus (T)> validate_input_func,TfLiteType expected_type)137 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
138 std::function<T(T)> func,
139 std::function<TfLiteStatus(T)> validate_input_func,
140 TfLiteType expected_type) {
141 const TfLiteTensor* input;
142 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
143 TfLiteTensor* output;
144 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
145 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
146 const int64_t num_elements = NumElements(input);
147 const T* in_data = GetTensorData<T>(input);
148 T* out_data = GetTensorData<T>(output);
149 for (int64_t i = 0; i < num_elements; ++i) {
150 if (validate_input_func) {
151 TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
152 }
153 out_data[i] = func(in_data[i]);
154 }
155 return kTfLiteOk;
156 }
157
158 // Non-quantized evaluation of Abs op when input is int16.
AbsInt16EvalImpl(TfLiteContext * context,TfLiteNode * node,TfLiteType expected_type)159 inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node,
160 TfLiteType expected_type) {
161 const TfLiteTensor* input;
162 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
163 TfLiteTensor* output;
164 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
165 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
166 const int64_t num_elements = NumElements(input);
167 const int16_t* in_data = GetTensorData<int16_t>(input);
168 int16_t* out_data = GetTensorData<int16_t>(output);
169 for (int64_t i = 0; i < num_elements; ++i) {
170 out_data[i] = static_cast<int16_t>(
171 std::abs<int32_t>(static_cast<int32_t>(in_data[i])));
172 }
173 return kTfLiteOk;
174 }
175
176 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,std::function<T (T)> func,TfLiteType expected_type)177 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
178 std::function<T(T)> func,
179 TfLiteType expected_type) {
180 return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
181 expected_type);
182 }
183
EvalNumeric(TfLiteContext * context,TfLiteNode * node,float float_func (float))184 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
185 float float_func(float)) {
186 return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
187 }
188
EvalLogical(TfLiteContext * context,TfLiteNode * node,bool bool_func (bool))189 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
190 bool bool_func(bool)) {
191 return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
192 }
193
ElementWiseQuantizedInit(TfLiteContext * context,const char * buffer,size_t length)194 void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
195 size_t length) {
196 return new OpData();
197 }
198
ElementWiseQuantizedFree(TfLiteContext * context,void * buffer)199 void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
200 delete static_cast<OpData*>(buffer);
201 }
202
203 template <typename T>
AbsEvalQuantized(TfLiteContext * context,TfLiteNode * node,TfLiteType type)204 TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
205 TfLiteType type) {
206 const auto* op_data = static_cast<const OpData*>(node->user_data);
207 const int kMin = std::numeric_limits<T>::min();
208 const int kMax = std::numeric_limits<T>::max();
209
210 std::function<T(T)> func = [&](T i) {
211 const int32_t value = std::abs(i - op_data->input_offset);
212 if (!op_data->needs_rescale) {
213 return static_cast<T>(
214 std::min(std::max(value + op_data->output_offset, kMin), kMax));
215 }
216 const int32_t output = MultiplyByQuantizedMultiplier(
217 value, op_data->multiplier, op_data->shift) +
218 op_data->output_offset;
219 return static_cast<T>(std::min(std::max(output, kMin), kMax));
220 };
221
222 return EvalImpl<T>(context, node, func, type);
223 }
224
AbsEval(TfLiteContext * context,TfLiteNode * node)225 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
226 const TfLiteTensor* input = GetInput(context, node, 0);
227 const TfLiteType type = input->type;
228 switch (type) {
229 case kTfLiteFloat32:
230 return EvalImpl<float>(context, node, std::abs<float>, type);
231 case kTfLiteInt8:
232 return AbsEvalQuantized<int8_t>(context, node, type);
233 case kTfLiteInt16:
234 return input->quantization.type == kTfLiteNoQuantization
235 ? AbsInt16EvalImpl(context, node, type)
236 : AbsEvalQuantized<int16_t>(context, node, type);
237 default:
238 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
239 TfLiteTypeGetName(type));
240 return kTfLiteError;
241 }
242 }
243
SinEval(TfLiteContext * context,TfLiteNode * node)244 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
245 return EvalNumeric(context, node, std::sin);
246 }
247
CosEval(TfLiteContext * context,TfLiteNode * node)248 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
249 return EvalNumeric(context, node, std::cos);
250 }
251
LogEval(TfLiteContext * context,TfLiteNode * node)252 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
253 return EvalNumeric(context, node, std::log);
254 }
255
SqrtEval(TfLiteContext * context,TfLiteNode * node)256 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
257 return EvalNumeric(context, node, std::sqrt);
258 }
259
// Quantized (int8) Rsqrt: out = 1/sqrt(in), computed entirely in fixed point.
// Inputs below the zero point (i.e. negative dequantized values) are rejected
// by the per-element validator before evaluation.
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
  // Rsqrt is undefined for negative values; fail evaluation on any element
  // whose raw value is below the input zero point.
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    // Fixed-point approximation of 1/sqrt(value).
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    // Apply the inverse-sqrt multiplier to unity; adding kShift keeps
    // precision in the intermediate integer result.
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    // Rescale to the output scale (multiplier/shift computed in Prepare),
    // undoing kShift, then re-center on the output zero point.
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    // Saturate to the int8 range.
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
293
RsqrtEval(TfLiteContext * context,TfLiteNode * node)294 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
295 const TfLiteType type = GetInput(context, node, 0)->type;
296 switch (type) {
297 case kTfLiteFloat32:
298 return EvalImpl<float>(
299 context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
300 case kTfLiteInt8:
301 return RsqrtEvalQuantized(context, node, type);
302 default:
303 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
304 TfLiteTypeGetName(type));
305 return kTfLiteError;
306 }
307 }
308
SquareEval(TfLiteContext * context,TfLiteNode * node)309 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
310 return EvalNumeric(context, node, [](float f) { return f * f; });
311 }
312
LogicalNotEval(TfLiteContext * context,TfLiteNode * node)313 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
314 return EvalLogical(context, node, [](bool v) { return !v; });
315 }
316
317 } // namespace
318 } // namespace elementwise
319
// Given a function...
// template<int T>
// int Foo(int b)
//
// typedef int(*Bar)(int);
//
// MSVC2015 will not see Foo<10> as the same type as Bar.
//
// This works around the issue by instantiating wrapper methods around
// elementwise::GenericPrepare() rather than using a templated
// elementwise::GenericPrepare method.
//
// Expands to a static TfLiteStatus `function_name(context, node)` wrapper
// that forwards to elementwise::GenericPrepare with the given type predicate
// and op-name string.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name) \
  static TfLiteStatus function_name(TfLiteContext* context,                   \
                                    TfLiteNode* node) {                       \
    return elementwise::GenericPrepare(context, node,                         \
                                       is_supported_type_function, type_name); \
  }
337
GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
                elementwise::kAbsName)

// Abs registers init/free handlers so OpData (quantization state) survives
// from Prepare to Eval.
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareAbs, elementwise::AbsEval};
  return &r;
}
347
GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin")

// Sin is float-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin,
                                 elementwise::SinEval};
  return &r;
}
355
GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos")

// Cos is float-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_COS() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos,
                                 elementwise::CosEval};
  return &r;
}
363
GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log")

// Log is float-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_LOG() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog,
                                 elementwise::LogEval};
  return &r;
}
371
GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")

// Sqrt is float-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_SQRT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSqrt, elementwise::SqrtEval};
  return &r;
}
379
GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType,
                elementwise::kRsqrtName)

// Rsqrt registers init/free handlers so OpData (quantization state) survives
// from Prepare to Eval.
TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareRsqrt, elementwise::RsqrtEval};
  return &r;
}
389
GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square")

// Square is float-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_SQUARE() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSquare, elementwise::SquareEval};
  return &r;
}
397
GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not")

// LogicalNot is bool-only and stateless: no init/free handlers needed.
TfLiteRegistration* Register_LOGICAL_NOT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot,
                                 elementwise::LogicalNotEval};
  return &r;
}
405
406 } // namespace builtin
407 } // namespace ops
408 } // namespace tflite
409