1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 
23 #include "tensorflow/lite/c/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/kernels/cpu_backend_context.h"
26 #include "tensorflow/lite/kernels/internal/common.h"
27 #include "tensorflow/lite/kernels/internal/compatibility.h"
28 #include "tensorflow/lite/kernels/internal/cppmath.h"
29 #include "tensorflow/lite/kernels/internal/optimized/integer_ops/leaky_relu.h"
30 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
31 #include "tensorflow/lite/kernels/internal/quantization_util.h"
32 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
33 #include "tensorflow/lite/kernels/internal/reference/gelu.h"
34 #include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
35 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
36 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
37 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
38 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
39 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
40 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
41 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
42 #include "tensorflow/lite/kernels/internal/tensor.h"
43 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
44 #include "tensorflow/lite/kernels/internal/types.h"
45 #include "tensorflow/lite/kernels/kernel_util.h"
46 
47 #if __aarch64__ && __clang__
48 #include <arm_neon.h>
49 #endif
50 
51 namespace tflite {
52 namespace ops {
53 namespace builtin {
54 namespace activations {
55 
56 // TODO(b/142762739): We should figure out a multi-threading plan for most of
57 // the activation ops below.
58 
59 enum KernelType {
60   kReference,
61   kGenericOptimized,
62   kFixedPointOptimized,
63 };
64 
65 struct OpData {
66   int32_t input_multiplier = 0;
67   int input_left_shift = 0;
68   int32_t input_range_radius = 0;
69   int diff_min = 0;
70   uint8_t table[256] = {0};
71 };
72 
73 struct SoftmaxOpData {
74   struct SoftmaxParams params = {};
75   float table[256];
76 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
77   uint8_t uint8_table1[256];
78   uint8_t uint8_table2[256];
79 #endif
80   static constexpr int kInt16LUTArraySize = lut_size<int16_t>();
81   int16_t exp_lut[kInt16LUTArraySize];  // int16 LUT for exp(x), where x is
82                                         // uniformly distributed in [-10.0, 0.0].
83   int16_t one_over_one_plus_x_lut[kInt16LUTArraySize];  // int16 LUT for
84                                                         // 1 / (1 + x), where x
85                                                         // is uniformly
86                                                         // distributed in [0.0, 1.0].
87 };
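// Rough sketch (an editor's reading, not the authoritative description) of how
// the two LUTs above are combined by the int16 softmax path in
// reference_ops::SoftmaxInt16: with d_i = x_i - max(x) in [-10, 0],
//   softmax_i = exp_lut(d_i) * (1 / sum_j exp_lut(d_j)),
// where the accumulated sum is >= 1 (the max element contributes exp(0) = 1)
// and is presumably normalized into [1, 2) so that its reciprocal can be read
// from one_over_one_plus_x_lut.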
88 
89 struct LogSoftmaxOpData : public OpData {
90   int32_t reverse_scaling_divisor = 0;
91   int32_t reverse_scaling_right_shift = 0;
92   struct SoftmaxParams params = {};
93   float f_table[256];
94 };
95 
96 struct LeakyReluOpData : public OpData {
97   int32_t output_multiplier_alpha = 0;
98   int32_t output_shift_alpha = 0;
99   int32_t output_multiplier_identity = 0;
100   int32_t output_shift_identity = 0;
101 };
102 
103 struct PreluOpData : public OpData {
104   int32_t output_multiplier_1 = 0;
105   int32_t output_shift_1 = 0;
106   int32_t output_multiplier_2 = 0;
107   int32_t output_shift_2 = 0;
108   bool requires_broadcast;
109 };
110 
111 struct HardSwishData {
112   HardSwishParams params;
113 };
114 
115 struct ReluOpData : public OpData {
116   int32_t output_multiplier = 0;
117   int output_shift = 0;
118 };
119 
120 namespace {
121 template <typename T>
122 void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
123                          TfLiteTensor* output,
124                          const std::function<float(float)>& transform) {
125   static_assert(sizeof(T) == 1, "Lookup table valid only for 8-bit");
126   const float inverse_scale = 1 / output->params.scale;
127   int32_t maxval = std::numeric_limits<T>::max();
128   int32_t minval = std::numeric_limits<T>::min();
129   for (int32_t val = minval; val <= maxval; ++val) {
130     const float dequantized =
131         input->params.scale * (val - input->params.zero_point);
132     const float transformed = transform(dequantized);
133     const float rescaled = std::round(transformed * inverse_scale);
134     const int32_t quantized =
135         static_cast<int32_t>(rescaled + output->params.zero_point);
136     data->table[static_cast<uint8_t>(static_cast<T>(val))] =
137         static_cast<uint8_t>(
138             static_cast<T>(std::max(std::min(maxval, quantized), minval)));
139   }
140 }
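// Worked example (illustrative values only, not taken from any model), for the
// logistic transform that SigmoidPrepare below passes in: with an int8 input
// of scale 0.1 / zero_point 0 and an int8 output of scale 1/256 /
// zero_point -128, the entry for val = 10 is computed as
//   dequantized = 0.1 * (10 - 0)       = 1.0
//   transformed = logistic(1.0)        ~ 0.7311
//   rescaled    = round(0.7311 * 256)  = 187
//   quantized   = 187 + (-128)         = 59   (already inside [-128, 127])
// and is stored at table[static_cast<uint8_t>(int8_t{10})] == table[10].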
141 
142 // TODO(b/143696793): move this to optimized_ops.
143 void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
144                           TfLiteTensor* output) {
145   const int size =
146       MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
147   uint8_t* output_data = GetTensorData<uint8_t>(output);
148   const uint8_t* input_data = GetTensorData<uint8_t>(input);
149   int i = 0;
150 #if __aarch64__ && __clang__
151   // This code uses ARM64-only instructions.
152   // TODO(b/143709993): Port to ARMv7
153 
154   // Load the tables into registers. (4*4 128-bit registers)
155   uint8x16x4_t table[4];
156   table[0] = vld1q_u8_x4(data->table + 16 * 4 * 0);
157   table[1] = vld1q_u8_x4(data->table + 16 * 4 * 1);
158   table[2] = vld1q_u8_x4(data->table + 16 * 4 * 2);
159   table[3] = vld1q_u8_x4(data->table + 16 * 4 * 3);
160 
161   // Vectorized loop; process uint8x16_t (16 elements) at a time.
162   constexpr int vectorized_16_loop_step = 16;
163   const int vectorized_16_loop_end =
164       size / vectorized_16_loop_step * vectorized_16_loop_step;
165   for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) {
166     uint8x16_t input = vld1q_u8(input_data + i);
167     uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input);
168     vst1q_u8(output_data + i, output);
169   }
170   // Postamble and non-ARM64 code: simple for loop.
171 #endif
172   for (; i < size; ++i) {
173     output_data[i] = data->table[input_data[i]];
174   }
175 }
176 
177 template <typename T>
178 void QuantizedReluX(float act_min, float act_max, const TfLiteTensor* input,
179                     TfLiteTensor* output, const ReluOpData* data) {
180   ReluParams params;
181   params.quantized_activation_min =
182       std::max(static_cast<int32_t>(std::numeric_limits<T>::min()),
183                output->params.zero_point +
184                    static_cast<int32>(roundf(act_min / output->params.scale)));
185   params.quantized_activation_max =
186       act_max == std::numeric_limits<float>::infinity()
187           ? static_cast<int32_t>(std::numeric_limits<T>::max())
188           : std::min(
189                 static_cast<int32_t>(std::numeric_limits<T>::max()),
190                 output->params.zero_point +
191                     static_cast<int32>(roundf(act_max / output->params.scale)));
192   params.input_offset = input->params.zero_point;
193   params.output_offset = output->params.zero_point;
194   params.output_multiplier = data->output_multiplier;
195   params.output_shift = data->output_shift;
196   optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input),
197                        GetTensorShape(output), GetTensorData<T>(output));
198 }
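// Worked example (hypothetical quantization, not from any particular model):
// for Relu6 on uint8 with output scale 6/255 and output zero_point 0,
//   quantized_activation_min = max(0,   0 + round(0 / (6/255))) = 0
//   quantized_activation_max = min(255, 0 + round(6 / (6/255))) = 255
// so the op clamps to the full uint8 range and the multiplier/shift only
// rescale input values into the output scale.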
199 
200 }  // namespace
201 
202 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
203   // This is a builtin op, so we don't use the contents in 'buffer', if any.
204   // Instead, we allocate a new object to carry information from Prepare() to
205   // Eval().
206   return new OpData;
207 }
208 
209 void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
210   return new SoftmaxOpData;
211 }
212 
213 void SoftmaxFree(TfLiteContext* context, void* buffer) {
214   delete reinterpret_cast<SoftmaxOpData*>(buffer);
215 }
216 
217 void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
218                      size_t length) {
219   return new LogSoftmaxOpData;
220 }
221 
222 void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) {
223   return new PreluOpData;
224 }
225 
226 void Free(TfLiteContext* context, void* buffer) {
227   delete reinterpret_cast<OpData*>(buffer);
228 }
229 
230 void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
231   delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
232 }
233 
234 void PreluFree(TfLiteContext* context, void* buffer) {
235   delete reinterpret_cast<PreluOpData*>(buffer);
236 }
237 
238 void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
239   return new HardSwishData;
240 }
241 
242 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
243   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
244   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
245   const TfLiteTensor* input;
246   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
247   TfLiteTensor* output;
248   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
249   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
250 
251   return context->ResizeTensor(context, output,
252                                TfLiteIntArrayCopy(input->dims));
253 }
254 
255 void* ReluInit(TfLiteContext* context, const char* buffer, size_t length) {
256   return new ReluOpData;
257 }
258 
259 void ReluFree(TfLiteContext* context, void* buffer) {
260   delete reinterpret_cast<ReluOpData*>(buffer);
261 }
262 
263 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
264   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
265   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
266   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
267   const TfLiteTensor* input;
268   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
269   TfLiteTensor* output;
270   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
271   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
272 
273   if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
274       input->type == kTfLiteInt16) {
275     double real_multiplier = input->params.scale / output->params.scale;
276     QuantizeMultiplier(real_multiplier, &data->output_multiplier,
277                        &data->output_shift);
278   }
279 
280   if (input->type == kTfLiteInt16) {
281     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
282     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
283   }
284 
285   return context->ResizeTensor(context, output,
286                                TfLiteIntArrayCopy(input->dims));
287 }
288 
289 void* LeakyReluInit(TfLiteContext* context, const char* buffer, size_t length) {
290   return new LeakyReluOpData;
291 }
292 
293 void LeakyReluFree(TfLiteContext* context, void* buffer) {
294   delete reinterpret_cast<LeakyReluOpData*>(buffer);
295 }
296 
297 void HardSwishFree(TfLiteContext* context, void* buffer) {
298   delete static_cast<HardSwishData*>(buffer);
299 }
300 
301 TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
302   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
303   TfLiteTensor* output;
304   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
305 
306   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
307     HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
308     HardSwishParams* params = &data->params;
309     const TfLiteTensor* input;
310     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
311     params->input_zero_point = input->params.zero_point;
312     params->output_zero_point = output->params.zero_point;
313     const float input_scale = input->params.scale;
314     const float hires_input_scale = (1.0f / 128.0f) * input_scale;
315     const float reluish_scale = 3.0f / 32768.0f;
316     const float output_scale = output->params.scale;
317 
318     const float output_multiplier = hires_input_scale / output_scale;
319 
320     int32_t output_multiplier_fixedpoint_int32;
321     QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
322                        &params->output_multiplier_exponent);
323     DownScaleInt32ToInt16Multiplier(
324         output_multiplier_fixedpoint_int32,
325         &params->output_multiplier_fixedpoint_int16);
326     TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0);
327 
328     const float reluish_multiplier = hires_input_scale / reluish_scale;
329     int32_t reluish_multiplier_fixedpoint_int32;
330     QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
331                        &params->reluish_multiplier_exponent);
332     DownScaleInt32ToInt16Multiplier(
333         reluish_multiplier_fixedpoint_int32,
334         &params->reluish_multiplier_fixedpoint_int16);
335   }
336   return kTfLiteOk;
337 }
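// Worked example (hypothetical scales, for illustration only): with
// input_scale = output_scale = 3.0f / 128.0f,
//   hires_input_scale  = (1/128) * (3/128)      = 3/16384
//   output_multiplier  = (3/16384) / (3/128)    = 1/128
//   reluish_multiplier = (3/16384) / (3/32768)  = 2.0
// QuantizeMultiplier() then expresses each value as a fixed-point multiplier
// plus a power-of-two exponent, and DownScaleInt32ToInt16Multiplier() narrows
// the int32 multipliers to the int16 form the kernel consumes.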
338 
339 TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
340   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
341   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
342   const TfLiteTensor* input;
343   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
344   TfLiteTensor* output;
345   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
346   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
347 
348   LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
349 
350   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
351       output->type == kTfLiteInt16) {
352     const auto* params =
353         reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
354 
355     double alpha_multiplier =
356         input->params.scale * params->alpha / output->params.scale;
357     QuantizeMultiplier(alpha_multiplier, &data->output_multiplier_alpha,
358                        &data->output_shift_alpha);
359     double identity_multiplier = input->params.scale / output->params.scale;
360     QuantizeMultiplier(identity_multiplier, &data->output_multiplier_identity,
361                        &data->output_shift_identity);
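    // Worked example (hypothetical values): with alpha = 0.2 and
    // input_scale == output_scale, alpha_multiplier is 0.2 and
    // identity_multiplier is 1.0; QuantizeMultiplier() encodes each as an
    // int32 fixed-point multiplier plus a shift for the kernel.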
362   }
363 
364   if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
365     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
366     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
367   }
368 
369   return context->ResizeTensor(context, output,
370                                TfLiteIntArrayCopy(input->dims));
371 }
372 
373 template <KernelType kernel_type>
374 TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
375   OpData* data = reinterpret_cast<OpData*>(node->user_data);
376 
377   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
378   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
379   const TfLiteTensor* input;
380   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
381   TfLiteTensor* output;
382   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
383   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
384 
385   if (kernel_type == kFixedPointOptimized) {
386     if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
387       static constexpr int kInputIntegerBits = 4;
388 
389       const double input_real_multiplier =
390           input->params.scale *
391           static_cast<double>(1 << (15 - kInputIntegerBits));
392 
393       const double q =
394           std::frexp(input_real_multiplier, &data->input_left_shift);
395       auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1LL << 15)));
396       data->input_multiplier = static_cast<int16_t>(q_fixed);
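      // Worked example (hypothetical uint8 input with scale 1/32 and
      // zero_point 128, i.e. real values in roughly [-4, 4)):
      //   input_real_multiplier = 0.03125 * 2^11 = 64.0
      //   frexp(64.0) = 0.5 * 2^7  ->  input_left_shift = 7
      //   q_fixed = round(0.5 * 2^15) = 16384  ->  input_multiplier = 16384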
397 
398       int16_t input_range_radius =
399           CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15);
400       data->input_range_radius = input_range_radius;
401     }
402   }
403 
404   if (kernel_type == kGenericOptimized || kernel_type == kReference) {
405     if (input->type == kTfLiteUInt8) {
406       PopulateLookupTable<uint8_t>(
407           data, input, output, [](float value) { return std::tanh(value); });
408     } else if (input->type == kTfLiteInt8) {
409       PopulateLookupTable<int8_t>(data, input, output,
410                                   [](float value) { return std::tanh(value); });
411     }
412   }
413 
414   if (input->type == kTfLiteInt16) {
415     static constexpr int kInputIntegerBits = 3;
416     static constexpr int kOutputFractionalBits = 15;
417 
418     // These operators are implemented in fixed-point arithmetic,
419     // which intrinsically wants symmetric ranges (zero_point==0)
420     // and power-of-two scales (power-of-two is abbreviated below as POT).
421     // While more general support would be possible by means of rescaling,
422     // that would add some overhead and some loss of accuracy and wouldn't
423     // be used at the moment as current quantized LSTM applications are
424     // happy with symmetric, power-of-two-scales quantization. So we just
425     // implement that narrow case only for now.
426 
427     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
428     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
429 
430     int input_scale_log2_rounded;
431     bool param_scale_pot =
432         CheckedLog2(input->params.scale, &input_scale_log2_rounded);
433 
434     data->input_left_shift =
435         (15 - kInputIntegerBits) + input_scale_log2_rounded;
436     param_scale_pot &=
437         (data->input_left_shift == 0 || data->input_left_shift == 1);
438 
439     if (!param_scale_pot) {
440       // Calculate multiplier to change input scale to 1/(3*4096)
441       // as required by the table lookup.
442       // The number 3.0 in the multiplier comes from here,
443       // because the interval is [-10.7, 10.7] instead of [-8, 8].
444       // So, in this scaling +/-2^17 represents +/-10.7.
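      // Worked example (hypothetical input scale of 0.001): the initial
      // multiplier is 0.001 * 4096 * 3 = 12.288; the loop below doubles it 11
      // times to ~25165.8 (the last value still <= 32767/2 was ~12582.9), so
      // input_multiplier becomes 25165 and input_left_shift becomes 11.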
445 
446       double multiplier = input->params.scale * 4096.0 * 3.0;
447       data->input_left_shift = 0;
448 
449       while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
450         data->input_left_shift++;
451         multiplier = multiplier * 2.0;
452       }
453 
454       data->input_multiplier = static_cast<int32_t>(multiplier);
455     }
456 
457     int output_scale_log2_rounded;
458     TF_LITE_ENSURE(
459         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
460     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
461                       -kOutputFractionalBits);
462   }
463 
464   return context->ResizeTensor(context, output,
465                                TfLiteIntArrayCopy(input->dims));
466 }
467 
468 template <KernelType kernel_type>
469 TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
470   OpData* data = reinterpret_cast<OpData*>(node->user_data);
471 
472   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
473   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
474   const TfLiteTensor* input;
475   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
476   TfLiteTensor* output;
477   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
478   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
479 
480   if (kernel_type == kFixedPointOptimized) {
481     if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
482       if (input->type == kTfLiteUInt8) {
483         TF_LITE_ENSURE_EQ(context, output->params.zero_point,
484                           std::numeric_limits<uint8_t>::min());
485       }
486       if (input->type == kTfLiteInt8) {
487         TF_LITE_ENSURE_EQ(context, output->params.zero_point,
488                           std::numeric_limits<int8_t>::min());
489       }
490       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
491 
492       static constexpr int kInputIntegerBits = 4;
493 
494       const double input_real_multiplier =
495           input->params.scale *
496           static_cast<double>(1 << (15 - kInputIntegerBits));
497 
498       const double q =
499           std::frexp(input_real_multiplier, &data->input_left_shift);
500       auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1LL << 15)));
501       data->input_multiplier = static_cast<int16_t>(q_fixed);
502 
503       int16_t input_range_radius =
504           CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15);
505       data->input_range_radius = input_range_radius;
506     }
507   }
508 
509   if (kernel_type == kGenericOptimized || kernel_type == kReference) {
510     if (input->type == kTfLiteUInt8) {
511       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
512       PopulateLookupTable<uint8_t>(data, input, output, [](float value) {
513         return 1.0f / (1.0f + std::exp(-value));
514       });
515     } else if (input->type == kTfLiteInt8) {
516       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
517       PopulateLookupTable<int8_t>(data, input, output, [](float value) {
518         return 1.0f / (1.0f + std::exp(-value));
519       });
520     } else if (input->type == kTfLiteInt16) {
521       TF_LITE_ENSURE(context, output->params.scale == 1. / 32768);
522       TF_LITE_ENSURE(context, output->params.zero_point == 0);
523     }
524   }
525 
526   if (input->type == kTfLiteInt16) {
527     static constexpr int kInputIntegerBits = 3;
528     static constexpr int kOutputFractionalBits = 15;
529 
530     // See comments in TanhPrepare about requiring zero_point==0
531     // and a power-of-two ("POT") scale.
532 
533     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
534     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
535 
536     int input_scale_log2_rounded;
537     bool param_scale_pot =
538         CheckedLog2(input->params.scale, &input_scale_log2_rounded);
539 
540     data->input_left_shift =
541         (15 - kInputIntegerBits) + input_scale_log2_rounded;
542     param_scale_pot &= (data->input_left_shift == 0);
543 
544     if (!param_scale_pot) {
545       // Calculate multiplier to change input scale to 1/(3*4096)
546       // as required by the table lookup.
547       // In this scaling +/-2^17 represents +/-10.7
548       double multiplier = input->params.scale * 4096.0 * 3.0;
549 
550       data->input_left_shift = 0;
551 
552       while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
553         data->input_left_shift++;
554         multiplier = multiplier * 2.0;
555       }
556 
557       data->input_multiplier = static_cast<int32_t>(multiplier);
558     }
559 
560     int output_scale_log2_rounded;
561     TF_LITE_ENSURE(
562         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
563     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
564                       -kOutputFractionalBits);
565   }
566 
567   return context->ResizeTensor(context, output,
568                                TfLiteIntArrayCopy(input->dims));
569 }
570 
571 template <KernelType kernel_type>
572 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
573   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
574   SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
575 
576   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
577   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
578   const TfLiteTensor* input;
579   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
580   TfLiteTensor* output;
581   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
582 
583   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
584 
585   if (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) {
586     TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
587     TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 256,
588                         (0.001f * 1.f / 256));
589   } else if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
590     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
591     TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
592                         (0.001f * 1.f / 32768));
593   }
594 
595   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
596     if (kernel_type == kReference) {
597       const int kScaledDiffIntegerBits = 5;
598       int input_left_shift;
599       tflite::PreprocessSoftmaxScaling(
600           static_cast<double>(params->beta),
601           static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
602           &data->params.input_multiplier, &input_left_shift);
603       data->params.input_left_shift = input_left_shift;
604       data->params.diff_min =
605           -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
606                                               input_left_shift);
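      // Roughly speaking: diff_min is the most negative (input - max(input))
      // difference that the fixed-point representation still covers; smaller
      // differences are treated as contributing essentially zero to the
      // softmax sum.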
607     } else {
608       switch (output->type) {
609         case kTfLiteUInt8:
610         case kTfLiteInt8:
611 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
612           // Only applies when both input and output are uint8/int8 and the
613           // build uses clang on aarch64.
614           // TODO(b/143709993): Port to ARMv7 and other platforms.
615           data->params.uint8_table1 = data->uint8_table1;
616           data->params.uint8_table2 = data->uint8_table2;
617           optimized_ops::PopulateSoftmaxUInt8LookupTable(
618               &data->params, input->params.scale, params->beta);
619           break;
620 #endif
621         case kTfLiteInt16:
622         default:
623           data->params.table = data->table;
624           optimized_ops::PopulateSoftmaxLookupTable(
625               &data->params, input->params.scale, params->beta);
626       }
627 
628       data->params.zero_point = output->params.zero_point;
629       data->params.scale = output->params.scale;
630     }
631   } else if (input->type == kTfLiteInt16) {
632     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
633     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
634 
635     data->params.exp_lut = data->exp_lut;
636     // The exp LUT is only used on negative values;
637     // we consider exp(-10.0) insignificant to the accumulation.
638     gen_lut<double, int16_t, int16_t>(
639         [](double value) { return std::exp(value); }, -10.0, 0.0, -1.0, 1.0,
640         data->params.exp_lut);
641     data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
642     gen_lut<double, int16_t, int16_t>(
643         [](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0, -1.0, 1.0,
644         data->params.one_over_one_plus_x_lut);
645     data->params.zero_point = output->params.zero_point;
646     data->params.scale = output->params.scale;
647 
648     double input_scale_beta_rescale =
649         input->params.scale * params->beta /
650         (10.0 / 65535.0);  // Scale the input_diff such that [-65535, 0]
651                            // corresponds to [-10.0, 0.0].
652     QuantizeMultiplier(input_scale_beta_rescale, &data->params.input_multiplier,
653                        &data->params.input_left_shift);
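    // Worked example (hypothetical): with beta = 1 and an input scale of
    // exactly 10.0 / 65535.0, input_scale_beta_rescale is 1.0, i.e. the int16
    // input differences already cover [-65535, 0] ~ [-10.0, 0.0] and only a
    // unit multiplier is needed.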
654   }
655 
656   return context->ResizeTensor(context, output,
657                                TfLiteIntArrayCopy(input->dims));
658 }
659 
660 template <KernelType kernel_type>
661 TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
662   LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
663 
664   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
665   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
666   const TfLiteTensor* input;
667   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
668   TfLiteTensor* output;
669   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
670   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
671 
672   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
673     TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
674     static const double kBeta = 1.0;
675     if (input->type == kTfLiteUInt8) {
676       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
677     }
678     if (input->type == kTfLiteInt8) {
679       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
680     }
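    // With the required scale of 16/256 = 0.0625 and zero_point of 255 (uint8)
    // or 127 (int8), the representable outputs are (q - zero_point) * 0.0625,
    // i.e. roughly [-15.94, 0], matching log-softmax's non-positive range.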
681 
682     if (kernel_type == kReference) {
683       const int kScaledDiffIntegerBits = 5;
684       int input_left_shift;
685       int reverse_scaling_right_shift;
686       tflite::PreprocessLogSoftmaxScalingExp(
687           kBeta, static_cast<double>(input->params.scale),
688           kScaledDiffIntegerBits, &data->params.input_multiplier,
689           &input_left_shift, &data->params.reverse_scaling_divisor,
690           &reverse_scaling_right_shift);
691       reverse_scaling_right_shift *= -1;
692       data->params.input_left_shift = input_left_shift;
693       data->params.reverse_scaling_right_shift = reverse_scaling_right_shift;
694       data->params.diff_min = -tflite::CalculateInputRadius(
695           kScaledDiffIntegerBits, input_left_shift);
696     } else {
697       data->params.table = data->f_table;
698       optimized_ops::PopulateSoftmaxLookupTable(&data->params,
699                                                 input->params.scale, kBeta);
700       data->params.zero_point = output->params.zero_point;
701       data->params.scale = output->params.scale;
702     }
703   }
704 
705   return context->ResizeTensor(context, output,
706                                TfLiteIntArrayCopy(input->dims));
707 }
708 
709 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
710   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
711   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
712   const TfLiteTensor* input;
713   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
714   TfLiteTensor* output;
715   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
716   const TfLiteTensor* alpha;
717   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
718   PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
719 
720   TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
721 
722   output->type = input->type;
723 
724   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
725     // prelu(x) = x if x >= 0 else x * alpha.
726     // So if we translate that for quantized computation:
727     //
728     // input_float = (input_q - input_zp) * input_scale
729     // output_float = (output_q - output_zp) * output_scale
730     // alpha_float = (alpha_q - alpha_zp) * alpha_scale
731     //
732     // When input_q - input_zp >= 0:
733     // output_q = (input_q - input_zp) * input_scale / output_scale + output_zp
734     // else:
735     // output_q = (input_q - input_zp) * (alpha_q - alpha_zp) * input_scale
736     //            * alpha_scale / output_scale + output_zp
737     //
738     // So for input_q - input_zp >= 0:
739     // output real multiplier 1 is input_scale / output_scale;
740     // for input_q - input_zp < 0:
741     // output real multiplier 2 is input_scale * alpha_scale / output_scale.
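    // Worked example (hypothetical scales): input_scale = 0.02,
    // alpha_scale = 0.01 and output_scale = 0.04 give
    //   real_multiplier_1 = 0.02 / 0.04        = 0.5
    //   real_multiplier_2 = 0.02 * 0.01 / 0.04 = 0.005
    // each of which QuantizeMultiplier() below turns into an int32 multiplier
    // and a shift.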
742     double real_multiplier_1 = input->params.scale / output->params.scale;
743     double real_multiplier_2 =
744         input->params.scale * alpha->params.scale / output->params.scale;
745     QuantizeMultiplier(real_multiplier_1, &data->output_multiplier_1,
746                        &data->output_shift_1);
747     QuantizeMultiplier(real_multiplier_2, &data->output_multiplier_2,
748                        &data->output_shift_2);
749   }
750 
751   data->requires_broadcast = !HaveSameShapes(input, alpha);
752   // PRelu (parametric ReLU) shares the same alpha value on the "shared axis".
753   // This means it's always required to "broadcast" alpha values in PRelu.
754   TfLiteIntArray* output_size = nullptr;
755   TF_LITE_ENSURE_OK(
756       context, CalculateShapeForBroadcast(context, input, alpha, &output_size));
757 
758   TF_LITE_ENSURE_OK(context,
759                     context->ResizeTensor(context, output, output_size));
760   // After broadcasting, the output shape should always be the same as the
761   // input shape.
762   TF_LITE_ENSURE(context, HaveSameShapes(input, output));
763 
764   return kTfLiteOk;
765 }
766 
767 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
768   const TfLiteTensor* input;
769   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
770   TfLiteTensor* output;
771   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
772   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
773   switch (input->type) {
774     case kTfLiteFloat32: {
775       optimized_ops::Relu(GetTensorShape(input), GetTensorData<float>(input),
776                           GetTensorShape(output), GetTensorData<float>(output));
777     } break;
778     // TODO(renjieliu): We may revisit the quantization calculation logic,
779     // the unbounded upper limit is actually hard to quantize.
780     case kTfLiteUInt8: {
781       QuantizedReluX<uint8_t>(0.0f, std::numeric_limits<float>::infinity(),
782                               input, output, data);
783     } break;
784     case kTfLiteInt8: {
785       QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(),
786                              input, output, data);
787     } break;
788     case kTfLiteInt16: {
789       QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(),
790                               input, output, data);
791     } break;
792     default:
793       TF_LITE_KERNEL_LOG(context,
794                          "Only float32, uint8, int8 and int16 are supported "
795                          "currently, got %s.",
796                          TfLiteTypeGetName(input->type));
797       return kTfLiteError;
798   }
799   return kTfLiteOk;
800 }
801 
802 TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
803   const TfLiteTensor* input;
804   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
805   TfLiteTensor* output;
806   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
807   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
808   switch (input->type) {
809     case kTfLiteFloat32: {
810       optimized_ops::Relu1(GetTensorShape(input), GetTensorData<float>(input),
811                            GetTensorShape(output),
812                            GetTensorData<float>(output));
813       return kTfLiteOk;
814     }
815     case kTfLiteUInt8: {
816       QuantizedReluX<uint8_t>(-1.0f, 1.0f, input, output, data);
817       return kTfLiteOk;
818     }
819     case kTfLiteInt8: {
820       QuantizedReluX<int8_t>(-1, 1, input, output, data);
821       return kTfLiteOk;
822     }
823     default:
824       TF_LITE_KERNEL_LOG(context,
825                          "Only float32, uint8 and int8 are supported "
826                          "currently, got %s.",
827                          TfLiteTypeGetName(input->type));
828       return kTfLiteError;
829   }
830 }
831 
832 template <KernelType kernel_type>
833 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
834   HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
835 
836   const TfLiteTensor* input;
837   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
838   TfLiteTensor* output;
839   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
840   switch (input->type) {
841     case kTfLiteFloat32: {
842       if (kernel_type == kReference) {
843         reference_ops::HardSwish(
844             GetTensorShape(input), GetTensorData<float>(input),
845             GetTensorShape(output), GetTensorData<float>(output));
846       } else {
847         optimized_ops::HardSwish(
848             GetTensorShape(input), GetTensorData<float>(input),
849             GetTensorShape(output), GetTensorData<float>(output));
850       }
851       return kTfLiteOk;
852     } break;
853     case kTfLiteUInt8: {
854       HardSwishParams& params = data->params;
855       if (kernel_type == kReference) {
856         reference_ops::HardSwish(
857             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
858             GetTensorShape(output), GetTensorData<uint8_t>(output));
859       } else {
860         optimized_ops::HardSwish(
861             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
862             GetTensorShape(output), GetTensorData<uint8_t>(output));
863       }
864       return kTfLiteOk;
865     } break;
866     case kTfLiteInt8: {
867       HardSwishParams& params = data->params;
868       if (kernel_type == kReference) {
869         reference_ops::HardSwish(
870             params, GetTensorShape(input), GetTensorData<int8_t>(input),
871             GetTensorShape(output), GetTensorData<int8_t>(output));
872       } else {
873         optimized_ops::HardSwish(
874             params, GetTensorShape(input), GetTensorData<int8_t>(input),
875             GetTensorShape(output), GetTensorData<int8_t>(output));
876       }
877       return kTfLiteOk;
878     } break;
879     default:
880       TF_LITE_KERNEL_LOG(
881           context,
882           "Only float32, uint8 and int8 are supported currently, got %s.",
883           TfLiteTypeGetName(input->type));
884       return kTfLiteError;
885   }
886 }
887 
888 TfLiteStatus Relu0to1Eval(TfLiteContext* context, TfLiteNode* node) {
889   const TfLiteTensor* input;
890   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
891   TfLiteTensor* output;
892   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
893   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
894   switch (input->type) {
895     case kTfLiteFloat32: {
896       optimized_ops::Relu0To1(
897           GetTensorShape(input), GetTensorData<float>(input),
898           GetTensorShape(output), GetTensorData<float>(output));
899       return kTfLiteOk;
900     }
901     case kTfLiteUInt8: {
902       QuantizedReluX<uint8_t>(0.0f, 1.0f, input, output, data);
903       return kTfLiteOk;
904     }
905     case kTfLiteInt8: {
906       QuantizedReluX<int8_t>(0, 1, input, output, data);
907       return kTfLiteOk;
908     }
909     default:
910       TF_LITE_KERNEL_LOG(context,
911                          "Only float32, uint8 and int8 are supported "
912                          "currently, got %s.",
913                          TfLiteTypeGetName(input->type));
914       return kTfLiteError;
915   }
916 }
917 
918 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
919   const TfLiteTensor* input;
920   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
921   TfLiteTensor* output;
922   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
923   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
924   switch (input->type) {
925     case kTfLiteFloat32: {
926       size_t elements = input->bytes / sizeof(float);
927       const float* in = GetTensorData<float>(input);
928       const float* in_end = in + elements;
929       float* out = GetTensorData<float>(output);
930       for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f);
931       return kTfLiteOk;
932     }
933     case kTfLiteUInt8:
934       QuantizedReluX<uint8_t>(0.0f, 6.0f, input, output, data);
935       return kTfLiteOk;
936     case kTfLiteInt8: {
937       QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data);
938       return kTfLiteOk;
939     }
940     case kTfLiteInt16: {
941       QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data);
942       return kTfLiteOk;
943     }
944     default:
945       TF_LITE_KERNEL_LOG(context,
946                          "Only float32, uint8, int8 and int16 are supported "
947                          "currently, got %s.",
948                          TfLiteTypeGetName(input->type));
949       return kTfLiteError;
950   }
951 }
952 
953 template <KernelType kernel_type>
954 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
955   OpData* data = reinterpret_cast<OpData*>(node->user_data);
956   const TfLiteTensor* input;
957   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
958   TfLiteTensor* output;
959   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
960   switch (input->type) {
961     case kTfLiteFloat32: {
962       if (kernel_type == kReference) {
963         reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
964                             GetTensorShape(output),
965                             GetTensorData<float>(output));
966       } else {
967         optimized_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
968                             GetTensorShape(output),
969                             GetTensorData<float>(output));
970       }
971       return kTfLiteOk;
972     } break;
973     case kTfLiteInt16: {
974       TanhParams params;
975       params.input_left_shift = data->input_left_shift;
976       if (kernel_type == kReference || (data->input_multiplier > 0)) {
977         reference_integer_ops::Tanh(
978             data->input_multiplier, data->input_left_shift,
979             GetTensorShape(input), GetTensorData<int16_t>(input),
980             GetTensorShape(output), GetTensorData<int16_t>(output));
981       } else {
982         optimized_ops::Tanh(
983             params, GetTensorShape(input), GetTensorData<int16_t>(input),
984             GetTensorShape(output), GetTensorData<int16_t>(output));
985       }
986       return kTfLiteOk;
987     } break;
988     case kTfLiteUInt8: {
989       if (kernel_type == kFixedPointOptimized) {
990         TanhParams params;
991         params.input_zero_point = input->params.zero_point;
992         params.input_range_radius = data->input_range_radius;
993         params.input_multiplier = data->input_multiplier;
994         params.input_left_shift = data->input_left_shift;
995         optimized_ops::Tanh16bitPrecision(
996             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
997             GetTensorShape(output), GetTensorData<uint8_t>(output));
998       } else {
999         EvalUsingLookupTable(data, input, output);
1000       }
1001       return kTfLiteOk;
1002     } break;
1003     case kTfLiteInt8: {
1004       if (kernel_type == kFixedPointOptimized) {
1005         TanhParams params;
1006         params.input_zero_point = input->params.zero_point;
1007         params.input_range_radius = data->input_range_radius;
1008         params.input_multiplier = data->input_multiplier;
1009         params.input_left_shift = data->input_left_shift;
1010         optimized_ops::Tanh16bitPrecision(
1011             params, GetTensorShape(input), GetTensorData<int8_t>(input),
1012             GetTensorShape(output), GetTensorData<int8_t>(output));
1013       } else {
1014         EvalUsingLookupTable(data, input, output);
1015       }
1016       return kTfLiteOk;
1017     } break;
1018     default:
1019       TF_LITE_KERNEL_LOG(context,
1020                          "Only float32, uint8, int16 and int8 are supported "
1021                          "currently, got %s.",
1022                          TfLiteTypeGetName(input->type));
1023       return kTfLiteError;
1024   }
1025 }
1026 
1027 // Sigmoid is also known as "Logistic".
1028 template <KernelType kernel_type>
1029 TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
1030   OpData* data = reinterpret_cast<OpData*>(node->user_data);
1031 
1032   const TfLiteTensor* input;
1033   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1034   TfLiteTensor* output;
1035   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1036   switch (input->type) {
1037     case kTfLiteFloat32: {
1038       if (kernel_type == kReference) {
1039         reference_ops::Logistic(
1040             GetTensorShape(input), GetTensorData<float>(input),
1041             GetTensorShape(output), GetTensorData<float>(output));
1042       } else {
1043         optimized_ops::Logistic(
1044             GetTensorShape(input), GetTensorData<float>(input),
1045             GetTensorShape(output), GetTensorData<float>(output));
1046       }
1047       break;
1048     }
1049     case kTfLiteInt16: {
1050       LogisticParams params;
1051       if (kernel_type == kReference || (data->input_multiplier > 0)) {
1052         const int size =
1053             MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
1054 
1055         reference_integer_ops::Logistic(
1056             data->input_multiplier, data->input_left_shift, size,
1057             GetTensorData<int16_t>(input), GetTensorData<int16_t>(output));
1058       } else {
1059         optimized_ops::Logistic(
1060             params, GetTensorShape(input), GetTensorData<int16_t>(input),
1061             GetTensorShape(output), GetTensorData<int16_t>(output));
1062       }
1063       break;
1064     }
1065     case kTfLiteUInt8: {
1066       if (kernel_type == kFixedPointOptimized) {
1067         LogisticParams params;
1068         params.input_zero_point = input->params.zero_point;
1069         params.input_range_radius = data->input_range_radius;
1070         params.input_multiplier = data->input_multiplier;
1071         params.input_left_shift = data->input_left_shift;
1072         optimized_ops::Logistic16bitPrecision(
1073             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1074             GetTensorShape(output), GetTensorData<uint8_t>(output));
1075       } else {
1076         EvalUsingLookupTable(data, input, output);
1077       }
1078       break;
1079     }
1080     case kTfLiteInt8: {
1081       if (kernel_type == kFixedPointOptimized) {
1082         LogisticParams params;
1083         params.input_zero_point = input->params.zero_point;
1084         params.input_range_radius = data->input_range_radius;
1085         params.input_multiplier = data->input_multiplier;
1086         params.input_left_shift = data->input_left_shift;
1087         optimized_ops::Logistic16bitPrecision(
1088             params, GetTensorShape(input), GetTensorData<int8_t>(input),
1089             GetTensorShape(output), GetTensorData<int8_t>(output));
1090       } else {
1091         EvalUsingLookupTable(data, input, output);
1092       }
1093       break;
1094     }
1095     default:
1096       TF_LITE_KERNEL_LOG(context,
1097                          "Only float32, uint8, int16 and int8 are supported "
1098                          "currently, got %s.",
1099                          TfLiteTypeGetName(input->type));
1100       return kTfLiteError;
1101   }
1102   return kTfLiteOk;
1103 }
1104 
1105 TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
1106                           TfLiteTensor* output, TfLiteSoftmaxParams* params,
1107                           KernelType kernel_type = kGenericOptimized) {
1108   SoftmaxParams op_params;
1109   op_params.beta = params->beta;
1110   if (kernel_type == kReference) {
1111     reference_ops::Softmax(op_params, GetTensorShape(input),
1112                            GetTensorData<float>(input), GetTensorShape(output),
1113                            GetTensorData<float>(output));
1114   } else {
1115     optimized_ops::Softmax(op_params, GetTensorShape(input),
1116                            GetTensorData<float>(input), GetTensorShape(output),
1117                            GetTensorData<float>(output),
1118                            CpuBackendContext::GetFromContext(context));
1119   }
1120   return kTfLiteOk;
1121 }
1122 
1123 template <typename In, typename Out>
1124 TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
1125                               TfLiteTensor* output, SoftmaxOpData* data,
1126                               KernelType kernel_type = kGenericOptimized) {
1127   if (kernel_type == kReference) {
1128     reference_ops::Softmax(data->params, GetTensorShape(input),
1129                            GetTensorData<In>(input), GetTensorShape(output),
1130                            GetTensorData<Out>(output));
1131   } else {
1132     optimized_ops::Softmax(data->params, GetTensorShape(input),
1133                            GetTensorData<In>(input), GetTensorShape(output),
1134                            GetTensorData<Out>(output));
1135   }
1136   return kTfLiteOk;
1137 }
1138 
1139 template <>
1140 TfLiteStatus SoftmaxQuantized<int8_t, int8_t>(TfLiteContext* context,
1141                                               const TfLiteTensor* input,
1142                                               TfLiteTensor* output,
1143                                               SoftmaxOpData* data,
1144                                               KernelType kernel_type) {
1145   if (kernel_type == kReference) {
1146     reference_ops::Softmax(data->params, GetTensorShape(input),
1147                            GetTensorData<int8_t>(input), GetTensorShape(output),
1148                            GetTensorData<int8_t>(output));
1149   } else {
1150 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
1151     optimized_ops::SoftmaxInt8LUT(
1152         data->params, GetTensorShape(input), GetTensorData<int8_t>(input),
1153         GetTensorShape(output), GetTensorData<int8_t>(output));
1154 #else
1155     optimized_ops::Softmax(data->params, GetTensorShape(input),
1156                            GetTensorData<int8_t>(input), GetTensorShape(output),
1157                            GetTensorData<int8_t>(output));
1158 #endif
1159   }
1160   return kTfLiteOk;
1161 }
1162 
1163 template <>
1164 TfLiteStatus SoftmaxQuantized<uint8_t, uint8_t>(TfLiteContext* context,
1165                                                 const TfLiteTensor* input,
1166                                                 TfLiteTensor* output,
1167                                                 SoftmaxOpData* data,
1168                                                 KernelType kernel_type) {
1169   if (kernel_type == kReference) {
1170     reference_ops::Softmax(
1171         data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1172         GetTensorShape(output), GetTensorData<uint8_t>(output));
1173   } else {
1174 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
1175     optimized_ops::SoftmaxInt8LUT(
1176         data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1177         GetTensorShape(output), GetTensorData<uint8_t>(output));
1178 #else
1179     optimized_ops::Softmax(
1180         data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1181         GetTensorShape(output), GetTensorData<uint8_t>(output));
1182 #endif
1183   }
1184   return kTfLiteOk;
1185 }
1186 
1187 template <>
1188 TfLiteStatus SoftmaxQuantized<int16, int16>(TfLiteContext* context,
1189                                             const TfLiteTensor* input,
1190                                             TfLiteTensor* output,
1191                                             SoftmaxOpData* data,
1192                                             KernelType kernel_type) {
1193   if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
1194     reference_ops::SoftmaxInt16(
1195         data->params, GetTensorShape(input), GetTensorData<int16_t>(input),
1196         GetTensorShape(output), GetTensorData<int16_t>(output));
1197     return kTfLiteOk;
1198   } else {
1199     TF_LITE_KERNEL_LOG(context,
1200                        "Only 1D, 2D, 3D and 4D tensors supported for int16 "
1201                        "input with int16 output, got %dD.",
1202                        NumDimensions(input));
1203     return kTfLiteError;
1204   }
1205 }
1206 
1207 template <KernelType kernel_type>
1208 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
1209   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
1210   SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
1211 
1212   const TfLiteTensor* input;
1213   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1214   TfLiteTensor* output;
1215   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1216 
1217   switch (input->type) {
1218     case kTfLiteFloat32: {
1219       return SoftmaxFloat(context, input, output, params, kernel_type);
1220     }
1221     case kTfLiteUInt8: {
1222       switch (output->type) {
1223         case kTfLiteUInt8:
1224           return SoftmaxQuantized<uint8_t, uint8_t>(context, input, output,
1225                                                     data, kernel_type);
1226         case kTfLiteInt16:
1227           return SoftmaxQuantized<uint8_t, int16_t>(context, input, output,
1228                                                     data, kernel_type);
1229         default:
1230           TF_LITE_KERNEL_LOG(context,
1231                              "Only uint8_t and int16_t outputs are supported "
1232                              "with uint8_t inputs currently, got %s.",
1233                              TfLiteTypeGetName(output->type));
1234           return kTfLiteError;
1235       }
1236     }
1237     case kTfLiteInt8: {
1238       switch (output->type) {
1239         case kTfLiteInt8:
1240           return SoftmaxQuantized<int8_t, int8_t>(context, input, output, data,
1241                                                   kernel_type);
1242         case kTfLiteInt16:
1243           return SoftmaxQuantized<int8_t, int16_t>(context, input, output, data,
1244                                                    kernel_type);
1245         default:
1246           TF_LITE_KERNEL_LOG(context,
1247                              "Only int8_t and int16_t outputs are supported "
1248                              "with int8_t inputs currently, got %s.",
1249                              TfLiteTypeGetName(output->type));
1250           return kTfLiteError;
1251       }
1252     }
1253     case kTfLiteInt16: {
1254       return SoftmaxQuantized<int16_t, int16_t>(context, input, output, data,
1255                                                 kernel_type);
1256     }
1257 
1258     default:
1259       TF_LITE_KERNEL_LOG(context,
1260                          "Only float32, uint8_t, int8_t, int16_t are supported "
1261                          "currently, got %s.",
1262                          TfLiteTypeGetName(input->type));
1263       return kTfLiteError;
1264   }
1265 }
1266 
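// Evaluates LogSoftmax for float32, uint8 and int8 inputs, choosing the
// optimized or reference kernel according to kernel_type.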
1267 template <KernelType kernel_type>
1268 TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
1269   const LogSoftmaxOpData* data =
1270       reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
1271   const TfLiteTensor* input;
1272   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1273   TfLiteTensor* output;
1274   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1275   switch (input->type) {
1276     case kTfLiteFloat32: {
1277       SoftmaxParams op_params;
1278       if (kernel_type == kGenericOptimized) {
1279         optimized_ops::LogSoftmax(
1280             op_params, GetTensorShape(input), GetTensorData<float>(input),
1281             GetTensorShape(output), GetTensorData<float>(output));
1282       } else {
1283         reference_ops::LogSoftmax(
1284             op_params, GetTensorShape(input), GetTensorData<float>(input),
1285             GetTensorShape(output), GetTensorData<float>(output));
1286       }
1287       return kTfLiteOk;
1288     }
1289     case kTfLiteUInt8: {
1290       const SoftmaxParams& op_params = data->params;
1291       if (kernel_type == kGenericOptimized) {
1292         optimized_ops::LogSoftmax(
1293             op_params, input->params.scale, GetTensorShape(input),
1294             GetTensorData<uint8_t>(input), GetTensorShape(output),
1295             GetTensorData<uint8_t>(output));
1296       } else {
1297         reference_ops::LogSoftmax(
1298             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1299             GetTensorShape(output), GetTensorData<uint8_t>(output));
1300       }
1301       return kTfLiteOk;
1302     }
1303     case kTfLiteInt8: {
1304       const SoftmaxParams& op_params = data->params;
1305       if (kernel_type == kGenericOptimized) {
1306         optimized_ops::LogSoftmax(
1307             op_params, input->params.scale, GetTensorShape(input),
1308             GetTensorData<int8_t>(input), GetTensorShape(output),
1309             GetTensorData<int8_t>(output));
1310       } else {
1311         const auto input_shape = GetTensorShape(input);
1312         const auto output_shape = GetTensorShape(output);
1313         const int trailing_dim = input_shape.DimensionsCount() - 1;
1314         const int outer_size =
1315             MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1316         const int depth =
1317             MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1318         reference_integer_ops::LogSoftmax(
1319             op_params.input_multiplier, op_params.input_left_shift,
1320             op_params.reverse_scaling_divisor,
1321             op_params.reverse_scaling_right_shift, op_params.diff_min,
1322             outer_size, depth, GetTensorData<int8_t>(input),
1323             GetTensorData<int8_t>(output));
1324       }
1325       return kTfLiteOk;
1326     }
1327     default:
1328       TF_LITE_KERNEL_LOG(
1329           context,
1330           "Only float32, uint8 and int8 are supported currently, got %s.",
1331           TfLiteTypeGetName(input->type));
1332       return kTfLiteError;
1333   }
1334 }
1335 
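// PRelu transfer function: passes non-negative values through unchanged and
// scales negative values by alpha.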
1336 template <typename T>
1337 T ApplyPrelu(T input, T alpha) {
1338   return input >= 0.0 ? input : input * alpha;
1339 }
1340 
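// Evaluates PRelu for float32 and quantized uint8/int8 inputs, broadcasting
// alpha when the input and alpha shapes require it.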
1341 template <KernelType kernel_type>
1342 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
1343   const TfLiteTensor* input;
1344   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1345   const TfLiteTensor* alpha;
1346   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
1347   TfLiteTensor* output;
1348   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1349   const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
1350   switch (input->type) {
1351     case kTfLiteFloat32: {
1352       if (kernel_type == kGenericOptimized) {
1353         tflite::ArithmeticParams op_params;
1354         bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
1355             GetTensorShape(input), GetTensorShape(alpha), &op_params);
1356         if (need_broadcast) {
1357           optimized_ops::BroadcastPReluDispatch(
1358               op_params, GetTensorShape(input), GetTensorData<float>(input),
1359               GetTensorShape(alpha), GetTensorData<float>(alpha),
1360               GetTensorShape(output), GetTensorData<float>(output),
1361               ApplyPrelu<float>);
1362         } else {
1363           const int flat_size =
1364               MatchingElementsSize(GetTensorShape(input), GetTensorShape(alpha),
1365                                    GetTensorShape(output));
1366           optimized_ops::PReluElementWise(
1367               flat_size, op_params, GetTensorData<float>(alpha),
1368               GetTensorData<float>(input), GetTensorData<float>(output));
1369         }
1370       } else {
1371         if (data->requires_broadcast) {
1372           reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
1373               GetTensorShape(input), GetTensorData<float>(input),
1374               GetTensorShape(alpha), GetTensorData<float>(alpha),
1375               GetTensorShape(output), GetTensorData<float>(output),
1376               ApplyPrelu<float>);
1377         } else {
1378           reference_ops::BinaryFunction<float, float, float>(
1379               GetTensorShape(input), GetTensorData<float>(input),
1380               GetTensorShape(alpha), GetTensorData<float>(alpha),
1381               GetTensorShape(output), GetTensorData<float>(output),
1382               ApplyPrelu<float>);
1383         }
1384       }
1385       return kTfLiteOk;
1386     }
1387     case kTfLiteUInt8: {
1388       PreluParams op_params;
1389       op_params.input_offset = -input->params.zero_point;
1390       op_params.alpha_offset = -alpha->params.zero_point;
1391       op_params.output_offset = output->params.zero_point;
1392       op_params.output_multiplier_1 = data->output_multiplier_1;
1393       op_params.output_shift_1 = data->output_shift_1;
1394       op_params.output_multiplier_2 = data->output_multiplier_2;
1395       op_params.output_shift_2 = data->output_shift_2;
1396       if (data->requires_broadcast) {
1397         reference_ops::BroadcastPrelu4DSlow(
1398             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1399             GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
1400             GetTensorShape(output), GetTensorData<uint8_t>(output));
1401       } else {
1402         reference_ops::Prelu(
1403             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1404             GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
1405             GetTensorShape(output), GetTensorData<uint8_t>(output));
1406       }
1407       return kTfLiteOk;
1408     }
1409     case kTfLiteInt8: {
1410       PreluParams op_params;
1411       op_params.input_offset = -input->params.zero_point;
1412       op_params.alpha_offset = -alpha->params.zero_point;
1413       op_params.output_offset = output->params.zero_point;
1414       op_params.output_multiplier_1 = data->output_multiplier_1;
1415       op_params.output_shift_1 = data->output_shift_1;
1416       op_params.output_multiplier_2 = data->output_multiplier_2;
1417       op_params.output_shift_2 = data->output_shift_2;
1418       if (data->requires_broadcast) {
1419         reference_ops::BroadcastPrelu4DSlow(
1420             op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
1421             GetTensorShape(alpha), GetTensorData<int8_t>(alpha),
1422             GetTensorShape(output), GetTensorData<int8_t>(output));
1423       } else {
1424         reference_ops::Prelu(
1425             op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
1426             GetTensorShape(alpha), GetTensorData<int8_t>(alpha),
1427             GetTensorShape(output), GetTensorData<int8_t>(output));
1428       }
1429       return kTfLiteOk;
1430     }
1431     default:
1432       TF_LITE_KERNEL_LOG(
1433           context,
1434           "Only float32, uint8 and int8 are supported currently, got %s.",
1435           TfLiteTypeGetName(input->type));
1436       return kTfLiteError;
1437   }
1438 }
1439 
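// Runs quantized LeakyRelu. Non-reference kernels use the optimized integer
// implementation for int16 inputs; all other cases fall back to the reference
// kernel.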
1440 template <KernelType kernel_type, typename T>
1441 void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
1442                        const LeakyReluOpData* data) {
1443   LeakyReluParams op_params;
1444 
1445   op_params.input_offset = input->params.zero_point;
1446   op_params.output_offset = output->params.zero_point;
1447   op_params.output_multiplier_alpha = data->output_multiplier_alpha;
1448   op_params.output_shift_alpha = data->output_shift_alpha;
1449   op_params.output_multiplier_identity = data->output_multiplier_identity;
1450   op_params.output_shift_identity = data->output_shift_identity;
1451   if (kernel_type != KernelType::kReference && input->type == kTfLiteInt16) {
1452     optimized_integer_ops::QuantizeLeakyRelu(
1453         op_params, GetTensorShape(input), GetTensorData<int16>(input),
1454         GetTensorShape(output), GetTensorData<int16>(output));
1455   } else {
1456     reference_ops::QuantizeLeakyRelu(
1457         op_params, GetTensorShape(input), GetTensorData<T>(input),
1458         GetTensorShape(output), GetTensorData<T>(output));
1459   }
1460 }
1461 
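// Evaluates LeakyRelu: float32 uses the optimized float kernel, while
// uint8/int8/int16 inputs go through QuantizeLeakyRelu.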
1462 template <KernelType kernel_type>
1463 TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
1464   const TfLiteTensor* input;
1465   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1466   TfLiteTensor* output;
1467   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1468   const auto* params =
1469       reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
1470   const LeakyReluOpData* data =
1471       reinterpret_cast<LeakyReluOpData*>(node->user_data);
1472 
1473   LeakyReluParams op_params;
1474   switch (input->type) {
1475     case kTfLiteFloat32: {
1476       op_params.alpha = params->alpha;
1477       optimized_ops::LeakyRelu(
1478           op_params, GetTensorShape(input), GetTensorData<float>(input),
1479           GetTensorShape(output), GetTensorData<float>(output));
1480       return kTfLiteOk;
1481     }
1482     case kTfLiteUInt8: {
1483       QuantizeLeakyRelu<kernel_type, uint8_t>(input, output, data);
1484       return kTfLiteOk;
1485     }
1486     case kTfLiteInt8: {
1487       QuantizeLeakyRelu<kernel_type, int8_t>(input, output, data);
1488       return kTfLiteOk;
1489     }
1490     case kTfLiteInt16: {
1491       QuantizeLeakyRelu<kernel_type, int16_t>(input, output, data);
1492       return kTfLiteOk;
1493     }
1494     default:
1495       TF_LITE_KERNEL_LOG(
1496           context,
1497           "Only float32, int8, int16 and uint8 are supported currently, got %s.",
1498           TfLiteTypeGetName(input->type));
1499       return kTfLiteError;
1500   }
1501 }
1502 
1503 TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
1504   const TfLiteTensor* input;
1505   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1506   TfLiteTensor* output;
1507   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1508   OpData* data = reinterpret_cast<OpData*>(node->user_data);
1509 
1510   // Use LUT to handle quantized elu path.
1511   if (input->type == kTfLiteInt8) {
1512     PopulateLookupTable<int8_t>(data, input, output, [](float value) {
1513       return value < 0.0f ? std::expm1(value) : value;
1514     });
1515   }
1516   return GenericPrepare(context, node);
1517 }
1518 
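// Evaluates ELU: float32 uses the optimized kernel, int8 reads from the
// lookup table built in EluPrepare.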
1519 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
1520   const TfLiteTensor* input;
1521   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1522   TfLiteTensor* output;
1523   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1524   switch (input->type) {
1525     case kTfLiteFloat32: {
1526       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
1527                          GetTensorShape(output), GetTensorData<float>(output));
1528       return kTfLiteOk;
1529     }
1530     case kTfLiteInt8: {
1531       OpData* data = reinterpret_cast<OpData*>(node->user_data);
1532       EvalUsingLookupTable(data, input, output);
1533       return kTfLiteOk;
1534     }
1535     default:
1536       TF_LITE_KERNEL_LOG(
1537           context, "Only float32 and int8 are supported currently, got %s.",
1538           TfLiteTypeGetName(input->type));
1539       return kTfLiteError;
1540   }
1541 }
1542 
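// Prepares GELU by building a lookup table for quantized int8/uint8 inputs
// before running the generic prepare step.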
1543 TfLiteStatus GeluPrepare(TfLiteContext* context, TfLiteNode* node) {
1544   const TfLiteTensor* input;
1545   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1546   TfLiteTensor* output;
1547   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1548   OpData* data = reinterpret_cast<OpData*>(node->user_data);
1549   auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data);
1550 
1551   if (input->type == kTfLiteInt8) {
1552     PopulateLookupTable<int8_t>(
1553         data, input, output, reference_ops::GeluTransform(params->approximate));
1554   } else if (input->type == kTfLiteUInt8) {
1555     PopulateLookupTable<uint8_t>(
1556         data, input, output, reference_ops::GeluTransform(params->approximate));
1557   }
1558   return GenericPrepare(context, node);
1559 }
1560 
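// Evaluates GELU: float32 uses the reference kernel, int8/uint8 read from the
// lookup table built in GeluPrepare.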
1561 TfLiteStatus GeluEval(TfLiteContext* context, TfLiteNode* node) {
1562   auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data);
1563   const TfLiteTensor* input;
1564   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1565   TfLiteTensor* output;
1566   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1567 
1568   switch (input->type) {
1569     case kTfLiteFloat32: {
1570       reference_ops::Gelu(GetTensorShape(input), GetTensorData<float>(input),
1571                           params->approximate, GetTensorShape(output),
1572                           GetTensorData<float>(output));
1573       return kTfLiteOk;
1574     }
1575     case kTfLiteInt8:
1576     case kTfLiteUInt8: {
1577       OpData* data = reinterpret_cast<OpData*>(node->user_data);
1578       EvalUsingLookupTable(data, input, output);
1579       return kTfLiteOk;
1580     }
1581     default:
1582       TF_LITE_KERNEL_LOG(
1583           context, "Only float32, int8 and uint8 are supported currently, got %s.",
1584           TfLiteTypeGetName(input->type));
1585       return kTfLiteError;
1586   }
1587   return kTfLiteOk;
1588 }
1589 
1590 }  // namespace activations
1591 
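// The registration functions below expose the activation kernels to the op
// resolver; the *_REF variants select the reference implementations.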
1592 TfLiteRegistration* Register_ELU() {
1593   static TfLiteRegistration r = {activations::Init, activations::Free,
1594                                  activations::EluPrepare, activations::EluEval};
1595   return &r;
1596 }
1597 
1598 TfLiteRegistration* Register_RELU() {
1599   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1600                                  activations::ReluPrepare,
1601                                  activations::ReluEval};
1602   return &r;
1603 }
1604 
1605 TfLiteRegistration* Register_RELU_N1_TO_1() {
1606   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1607                                  activations::ReluPrepare,
1608                                  activations::Relu1Eval};
1609   return &r;
1610 }
1611 
1612 TfLiteRegistration* Register_RELU6() {
1613   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1614                                  activations::ReluPrepare,
1615                                  activations::Relu6Eval};
1616   return &r;
1617 }
1618 
1619 TfLiteRegistration* Register_RELU_0_TO_1() {
1620   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1621                                  activations::ReluPrepare,
1622                                  activations::Relu0to1Eval};
1623   return &r;
1624 }
1625 
1626 TfLiteRegistration* Register_TANH_REF() {
1627   static TfLiteRegistration r = {
1628       activations::Init, activations::Free,
1629       activations::TanhPrepare<activations::kReference>,
1630       activations::TanhEval<activations::kReference>};
1631   return &r;
1632 }
1633 
1634 TfLiteRegistration* Register_TANH_GENERIC_OPT() {
1635   static TfLiteRegistration r = {
1636       activations::Init, activations::Free,
1637       activations::TanhPrepare<activations::kGenericOptimized>,
1638       activations::TanhEval<activations::kGenericOptimized>};
1639   return &r;
1640 }
1641 
1642 TfLiteRegistration* Register_TANH_FIXED_POINT_OPT() {
1643   static TfLiteRegistration r = {
1644       activations::Init, activations::Free,
1645       activations::TanhPrepare<activations::kFixedPointOptimized>,
1646       activations::TanhEval<activations::kFixedPointOptimized>};
1647   return &r;
1648 }
1649 
1650 TfLiteRegistration* Register_TANH() {
1651   // TODO(b/134622898): Switch over from the LUT optimized method to the fixed
1652   // point optimized method when typical Android hardware performs better on
1653   // the latter one.
1654   return Register_TANH_GENERIC_OPT();
1655 }
1656 
1657 TfLiteRegistration* Register_LOGISTIC_REF() {
1658   static TfLiteRegistration r = {
1659       activations::Init, activations::Free,
1660       activations::SigmoidPrepare<activations::kReference>,
1661       activations::SigmoidEval<activations::kReference>};
1662   return &r;
1663 }
1664 
1665 TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT() {
1666   static TfLiteRegistration r = {
1667       activations::Init, activations::Free,
1668       activations::SigmoidPrepare<activations::kGenericOptimized>,
1669       activations::SigmoidEval<activations::kGenericOptimized>};
1670   return &r;
1671 }
1672 
1673 TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT() {
1674   static TfLiteRegistration r = {
1675       activations::Init, activations::Free,
1676       activations::SigmoidPrepare<activations::kFixedPointOptimized>,
1677       activations::SigmoidEval<activations::kFixedPointOptimized>};
1678   return &r;
1679 }
1680 
1681 TfLiteRegistration* Register_LOGISTIC() {
1682   // TODO(b/134622898): Switch over from the LUT optimized method to the fixed
1683   // point optimized method when typical Android hardware performs better on
1684   // the latter one.
1685   return Register_LOGISTIC_GENERIC_OPT();
1686 }
1687 
1688 TfLiteRegistration* Register_SOFTMAX_REF() {
1689   static TfLiteRegistration r = {
1690       activations::SoftmaxInit, activations::SoftmaxFree,
1691       activations::SoftmaxPrepare<activations::kReference>,
1692       activations::SoftmaxEval<activations::kReference>};
1693   return &r;
1694 }
1695 
1696 TfLiteRegistration* Register_SOFTMAX() {
1697   static TfLiteRegistration r = {
1698       activations::SoftmaxInit, activations::SoftmaxFree,
1699       activations::SoftmaxPrepare<activations::kGenericOptimized>,
1700       activations::SoftmaxEval<activations::kGenericOptimized>};
1701   return &r;
1702 }
1703 
1704 TfLiteRegistration* Register_LOG_SOFTMAX_REF() {
1705   static TfLiteRegistration r = {
1706       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
1707       activations::LogSoftmaxPrepare<activations::kReference>,
1708       activations::LogSoftmaxEval<activations::kReference>};
1709   return &r;
1710 }
1711 
1712 TfLiteRegistration* Register_LOG_SOFTMAX() {
1713   static TfLiteRegistration r = {
1714       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
1715       activations::LogSoftmaxPrepare<activations::kGenericOptimized>,
1716       activations::LogSoftmaxEval<activations::kGenericOptimized>};
1717   return &r;
1718 }
1719 
1720 TfLiteRegistration* Register_PRELU_REF() {
1721   static TfLiteRegistration r = {
1722       activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
1723       activations::PreluEval<activations::kReference>};
1724   return &r;
1725 }
1726 
1727 TfLiteRegistration* Register_PRELU() {
1728   static TfLiteRegistration r = {
1729       activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
1730       activations::PreluEval<activations::kGenericOptimized>};
1731   return &r;
1732 }
1733 
1734 TfLiteRegistration* Register_LEAKY_RELU_REF() {
1735   static TfLiteRegistration r = {
1736       activations::LeakyReluInit, activations::LeakyReluFree,
1737       activations::LeakyReluPrepare,
1738       activations::LeakyReluEval<activations::kReference>};
1739   return &r;
1740 }
1741 
1742 TfLiteRegistration* Register_LEAKY_RELU() {
1743   static TfLiteRegistration r = {
1744       activations::LeakyReluInit, activations::LeakyReluFree,
1745       activations::LeakyReluPrepare,
1746       activations::LeakyReluEval<activations::kGenericOptimized>};
1747   return &r;
1748 }
1749 
1750 TfLiteRegistration* Register_HARD_SWISH() {
1751   static TfLiteRegistration r = {
1752       activations::HardSwishInit, activations::HardSwishFree,
1753       activations::HardSwishPrepare,
1754       activations::HardSwishEval<activations::kGenericOptimized>};
1755   return &r;
1756 }
1757 
1758 TfLiteRegistration* Register_HARD_SWISH_REF() {
1759   static TfLiteRegistration r = {
1760       activations::HardSwishInit, activations::HardSwishFree,
1761       activations::HardSwishPrepare,
1762       activations::HardSwishEval<activations::kReference>};
1763   return &r;
1764 }
1765 
1766 TfLiteRegistration* Register_GELU() {
1767   static TfLiteRegistration r = {activations::Init, activations::Free,
1768                                  activations::GeluPrepare,
1769                                  activations::GeluEval};
1770   return &r;
1771 }
1772 
1773 }  // namespace builtin
1774 }  // namespace ops
1775 }  // namespace tflite
1776