/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace l2norm {

// This file has two implementations of L2Norm: a reference kernel and a
// generic optimized kernel.
enum KernelType {
  kReference,
  kGenericOptimized,
};

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE(context, NumDimensions(input) <= 4);

  TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
                              output->type == kTfLiteUInt8 ||
                              output->type == kTfLiteInt8);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

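  // For quantized output the quantization parameters are fixed by the op:
  // L2-normalized values lie in [-1, 1], so a scale of 1/128 together with a
  // zero point of 128 (uint8) or 0 (int8) covers that range exactly.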
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
    if (output->type == kTfLiteUInt8) {
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
    }
    if (output->type == kTfLiteInt8) {
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    }
  }

  // TODO(ahentz): For some reason our implementations don't support
  // activations.
  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_size);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // TODO(b/143912164): instead of hardcoding the epsilon here, we should read
  // it from tensorflow, i.e., by adding a param.
  // We don't compute epsilon for the quantized kernel:
  //
  // epsilon_float = (epsilon_quant - zp) * scale
  // so
  // epsilon_quant = epsilon_float / scale + zp
  // We know epsilon_float is just a very small number used to avoid a
  // division-by-zero error, and scale is > 1, so the integer value of epsilon
  // for quant is dominated by the zero point.
  // Also, GetInvSqrtQuantizedMultiplierExp already handles the case where the
  // sum of squared input values is zero.
  // So we don't even need to handle epsilon for the quantized kernel.
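  // As a small worked example (values are hypothetical, not taken from this
  // kernel): with scale = 2.0 and zp = 128,
  //   epsilon_quant = 1e-6 / 2.0 + 128 ~= 128,
  // i.e. the quantized epsilon is essentially just the zero point.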
  const float epsilon = 1e-6f;
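  // TF_LITE_L2NORM expands to a call to <namespace>::L2Normalization, where
  // the namespace (reference_ops or optimized_ops) is selected at compile time
  // by the kernel_type template parameter.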
  if (output->type == kTfLiteFloat32) {
#define TF_LITE_L2NORM(type)                                                 \
  tflite::L2NormalizationParams op_params;                                   \
  op_params.input_zero_point = 0;                                            \
  type::L2Normalization(op_params, GetTensorShape(input),                    \
                        GetTensorData<float>(input), GetTensorShape(output), \
                        GetTensorData<float>(output), epsilon)

    if (kernel_type == kReference) {
      TF_LITE_L2NORM(reference_ops);
    }
    if (kernel_type == kGenericOptimized) {
      TF_LITE_L2NORM(optimized_ops);
    }
#undef TF_LITE_L2NORM
  } else if (output->type == kTfLiteUInt8) {
#define TF_LITE_L2NORM(type)                                                 \
  tflite::L2NormalizationParams op_params;                                   \
  op_params.input_zero_point = input->params.zero_point;                     \
  type::L2Normalization(op_params, GetTensorShape(input),                    \
                        GetTensorData<uint8>(input), GetTensorShape(output), \
                        GetTensorData<uint8>(output))

    if (kernel_type == kReference) {
      TF_LITE_L2NORM(reference_ops);
    }
    if (kernel_type == kGenericOptimized) {
      TF_LITE_L2NORM(optimized_ops);
    }
#undef TF_LITE_L2NORM
  } else if (output->type == kTfLiteInt8) {
    const auto input_shape = GetTensorShape(input);
    const auto output_shape = GetTensorShape(output);
    const int trailing_dim = input_shape.DimensionsCount() - 1;
    const int depth =
        MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
    const int outer_size =
        MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
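    // For example (shape is hypothetical): a 1x4x4x8 input gives depth = 8
    // (the trailing dimension) and outer_size = 1 * 4 * 4 = 16, so the int8
    // kernel normalizes 16 independent vectors of length 8.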
    reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
                                           depth, GetTensorData<int8>(input),
                                           GetTensorData<int8>(output));
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Output type is %s, requires float, uint8 or int8.",
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace l2norm

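// Each TfLiteRegistration below is {init, free, prepare, invoke}; L2Norm keeps
// no per-node state, so init and free are left as nullptr.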
TfLiteRegistration* Register_L2NORM_REF() {
  static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
                                 l2norm::Eval<l2norm::kReference>};
  return &r;
}

TfLiteRegistration* Register_L2NORM_GENERIC_OPT() {
  static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
                                 l2norm::Eval<l2norm::kGenericOptimized>};
  return &r;
}

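// The default registration for the builtin L2_NORMALIZATION op resolves to the
// generic optimized kernel.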
TfLiteRegistration* Register_L2_NORMALIZATION() {
  return Register_L2NORM_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
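
// A minimal usage sketch (not part of this file): a mutable op resolver wires
// this kernel up for the builtin L2_NORMALIZATION op, e.g.
//   resolver.AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
//                       tflite::ops::builtin::Register_L2_NORMALIZATION());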