xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/numeric_verify.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdlib.h>
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <numeric>
22 #include <vector>
23 
24 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/kernels/dequantize.h"
27 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
28 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
29 #include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
30 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
31 #include "tensorflow/lite/kernels/internal/tensor.h"
32 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
33 #include "tensorflow/lite/kernels/kernel_util.h"
34 
35 namespace tflite {
36 namespace ops {
37 namespace custom {
38 namespace numeric_verify {
39 
// Keys of the custom-op options encoded in the flexbuffer map (see Init).
static constexpr const char kToleranceStr[] = "tolerance";
static constexpr const char kLogIfFailedStr[] = "log_if_failed";
// The op has a single temporary (the dequantized input) and a single output
// (the elementwise difference against the float reference).
static constexpr const int kTemporaryDequantizedTensor = 0;
static constexpr const int kOutputTensor = 0;
44 
45 struct OpContext {
OpContexttflite::ops::custom::numeric_verify::OpContext46   OpContext(TfLiteContext* context, TfLiteNode* node) {
47     input = GetInput(context, node, 0);
48     ref = GetInput(context, node, 1);
49     output = GetOutput(context, node, 0);
50   }
51   const TfLiteTensor* input;
52   const TfLiteTensor* ref;
53   TfLiteTensor* output;
54 };
55 
// Sentinel meaning no temporary tensor index has been claimed from the
// context yet.
const int kTensorNotAllocated = -1;

// Per-node state that persists across Prepare/Eval invocations.
struct OpData {
  // Allowed error expressed as a fraction of the input's quantized value
  // range. Must be a number less than 1.0.
  float tolerance;
  // Set once a constant input has been dequantized and verified, so Eval can
  // early-out on subsequent invocations. Only meaningful for constant inputs.
  bool float_input_initialized;
  // Context-assigned index of the cached temporary (dequantized) tensor.
  int cache_tensor_id = kTensorNotAllocated;
  // When true (and tolerance >= 0.1), Eval fails on any out-of-tolerance
  // element instead of just logging statistics.
  bool log_if_failed;
};
67 
Init(TfLiteContext * context,const char * buffer,size_t length)68 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
69   auto* op_data = new OpData();
70   op_data->float_input_initialized = false;
71 
72   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
73   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
74   const float tolerance = m[kToleranceStr].AsFloat();
75   const bool log_if_failed = m[kLogIfFailedStr].AsBool();
76   op_data->tolerance = tolerance;
77   op_data->log_if_failed = log_if_failed;
78 
79   return op_data;
80 }
81 
Free(TfLiteContext * context,void * buffer)82 void Free(TfLiteContext* context, void* buffer) {
83   delete reinterpret_cast<OpData*>(buffer);
84 }
85 
Prepare(TfLiteContext * context,TfLiteNode * node)86 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
87   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
88   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
89   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
90 
91   OpContext op_context(context, node);
92 
93   TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
94                               op_context.input->type == kTfLiteInt8 ||
95                               op_context.input->type == kTfLiteInt16 ||
96                               op_context.input->type == kTfLiteFloat16);
97   TF_LITE_ENSURE(context, op_context.ref->type == kTfLiteFloat32);
98 
99   // Allocate tensor to store the dequantized inputs.
100   if (op_data->cache_tensor_id == kTensorNotAllocated) {
101     TF_LITE_ENSURE_OK(
102         context, context->AddTensors(context, 1, &op_data->cache_tensor_id));
103   }
104 
105   TfLiteIntArrayFree(node->temporaries);
106   node->temporaries = TfLiteIntArrayCreate(1);
107   node->temporaries->data[0] = op_data->cache_tensor_id;
108 
109   TfLiteTensor* dequantized;
110   TF_LITE_ENSURE_OK(context,
111                     GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
112                                      &dequantized));
113   dequantized->type = op_context.ref->type;
114   dequantized->allocation_type = kTfLiteDynamic;
115 
116   TF_LITE_ENSURE_OK(context, context->ResizeTensor(
117                                  context, dequantized,
118                                  TfLiteIntArrayCopy(op_context.input->dims)));
119 
120   TF_LITE_ENSURE_OK(
121       context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
122   op_context.output->type = kTfLiteFloat32;
123   op_context.output->allocation_type = kTfLiteArenaRwPersistent;
124   return context->ResizeTensor(context, op_context.output,
125                                TfLiteIntArrayCopy(op_context.input->dims));
126 }
127 
GetQuantizedValue(const OpContext & op_context,int index)128 static int32_t GetQuantizedValue(const OpContext& op_context, int index) {
129   switch (op_context.input->type) {
130     case kTfLiteUInt8:
131       return GetTensorData<uint8_t>(op_context.input)[index];
132     case kTfLiteInt8:
133       return GetTensorData<int8_t>(op_context.input)[index];
134     case kTfLiteInt16:
135       return GetTensorData<int16_t>(op_context.input)[index];
136     default:
137       return 0;
138   }
139 }
140 
141 template <builtin::dequantize::KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)142 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
143   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
144   OpContext op_context(context, node);
145   if (IsConstantTensor(op_context.input) && op_data->float_input_initialized) {
146     return kTfLiteOk;
147   }
148 
149   // Dequantize the input
150   TfLiteTensor* dequantized;
151   TF_LITE_ENSURE_OK(context,
152                     GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
153                                      &dequantized));
154   auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
155       context, node, op_context.input, dequantized);
156   if (status != kTfLiteOk) {
157     return status;
158   }
159 
160   if (IsConstantTensor(op_context.input)) {
161     op_data->float_input_initialized = true;
162   }
163 
164   TF_LITE_ENSURE_OK(
165       context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
166   auto output_data = GetTensorData<float>(op_context.output);
167 
168   // If log_if_failed is on, calculate differences between float and
169   // quantized values, their statistics and output logs.
170   // Throw errors if any diff greater than tolerance exists.
171   const int n = NumElements(dequantized);
172   if (op_data->log_if_failed && op_data->tolerance >= 0.1) {
173     // Verify the dequantized output.
174     auto max_diff = op_data->tolerance * op_context.input->params.scale;
175     for (int i = 0; i < n; ++i) {
176       int32_t value = GetQuantizedValue(op_context, i);
177       float dequant = GetTensorData<float>(dequantized)[i];
178       float reference = GetTensorData<float>(op_context.ref)[i];
179       output_data[i] = dequant - reference;
180       float diff = std::abs(output_data[i]);
181       if (diff > max_diff) {
182         TF_LITE_KERNEL_LOG(
183             context,
184             "Mismatch: %f is quantized to %d with (%f, %d). "
185             "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n",
186             reference, value, op_context.input->params.scale,
187             op_context.input->params.zero_point, reference, dequant, diff,
188             max_diff, op_data->tolerance);
189         return kTfLiteError;
190       }
191     }
192   } else {
193     // If tolerance is small or log_if_failed is off, then we only care about
194     // statistics.
195     // These statistics logging was added to identify some errors in practice.
196     std::vector<double> diffs, temp;
197     diffs.reserve(n);
198     temp.reserve(n);
199     diffs.resize(n);
200     temp.resize(n);
201     for (int i = 0; i < n; ++i) {
202       float dequant = GetTensorData<float>(dequantized)[i];
203       float reference = GetTensorData<float>(op_context.ref)[i];
204       diffs[i] = static_cast<double>(dequant - reference);
205       output_data[i] = dequant - reference;
206     }
207     double mean =
208         std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
209     double max_diff = 0.0;
210     std::transform(diffs.begin(), diffs.end(), temp.begin(),
211                    [mean, &max_diff](double x) {
212                      max_diff = std::max(max_diff, std::abs(x));
213                      return x - mean;
214                    });
215     double sq_sum =
216         std::inner_product(temp.begin(), temp.end(), temp.begin(), 0.0);
217     double std = std::sqrt(sq_sum / diffs.size());
218     TF_LITE_KERNEL_LOG(
219         context,
220         "std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n", std,
221         mean, max_diff, op_context.input->params.scale,
222         op_context.input->params.zero_point);
223   }
224   return kTfLiteOk;
225 }
226 
227 }  // namespace numeric_verify
228 
Register_NUMERIC_VERIFY_OPT()229 TfLiteRegistration* Register_NUMERIC_VERIFY_OPT() {
230   static TfLiteRegistration r = {
231       numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare,
232       numeric_verify::Eval<builtin::dequantize::kGenericOptimized>};
233   return &r;
234 }
235 
Register_NUMERIC_VERIFY_REF()236 TfLiteRegistration* Register_NUMERIC_VERIFY_REF() {
237   static TfLiteRegistration r = {
238       numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare,
239       numeric_verify::Eval<builtin::dequantize::kReference>};
240   return &r;
241 }
242 
Register_NUMERIC_VERIFY()243 TfLiteRegistration* Register_NUMERIC_VERIFY() {
244 #ifdef USE_NEON
245   return Register_NUMERIC_VERIFY_OPT();
246 #else
247   return Register_NUMERIC_VERIFY_REF();
248 #endif
249 }
250 
251 }  // namespace custom
252 }  // namespace ops
253 }  // namespace tflite
254