xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/optimize/model_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/tools/optimize/model_utils.h"

#include <fstream>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_conversion_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/operator_property.h"
29 
30 namespace tflite {
31 namespace optimize {
32 namespace utils {
33 
34 namespace {
35 
36 // Returns the index of the OpCode.
37 // If a OpCode doesn't exist, adds it and returns its index.
GetOrInsertOpCodeIndex(ModelT * model,const BuiltinOperator & op_code,int32_t version)38 int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
39                                int32_t version) {
40   for (size_t i = 0; i < model->operator_codes.size(); ++i) {
41     if (GetBuiltinCode(model->operator_codes[i].get()) == op_code) {
42       return i;
43     }
44   }
45   model->operator_codes.push_back(std::make_unique<OperatorCodeT>());
46   int op_code_idx = model->operator_codes.size() - 1;
47   model->operator_codes[op_code_idx]->builtin_code = op_code;
48   model->operator_codes[op_code_idx]->deprecated_builtin_code =
49       ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code);
50   // Version 2 and onwards supports INT8 inputs.
51   model->operator_codes[op_code_idx]->version = version;
52 
53   // Return the index of the newly placed OperatorCodeT.
54   return op_code_idx;
55 }
56 
57 }  // namespace
58 
59 // Creates a Dequantize OperatorT object.
MakeDequantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)60 void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
61                             int32_t input, int32_t output) {
62   OperatorT* op_raw = new OperatorT;
63   // Version 2 and onwards supports INT8 inputs.
64   op_raw->opcode_index =
65       GetOrInsertOpCodeIndex(model, BuiltinOperator_DEQUANTIZE, 2);
66   op_raw->inputs = {input};
67   op_raw->outputs = {output};
68 
69   op->reset(op_raw);
70 }
71 
72 // Creates a Quantize OperatorT object.
MakeQuantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)73 void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
74                           int32_t input, int32_t output) {
75   OperatorT* op_raw = new OperatorT;
76   op_raw->opcode_index =
77       GetOrInsertOpCodeIndex(model, BuiltinOperator_QUANTIZE, 1);
78   op_raw->inputs = {input};
79   op_raw->outputs = {output};
80 
81   op->reset(op_raw);
82 }
83 
84 // Create a new TensorT object without quantization parameters.
MakeTensor(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,std::unique_ptr<TensorT> * tensor)85 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
86                 const std::vector<int32_t>& shape_signature,
87                 const TensorType& type, std::unique_ptr<TensorT>* tensor) {
88   TensorT* tensor_raw = new TensorT;
89   tensor_raw->name = name;
90   tensor_raw->shape = shape;
91   if (!shape_signature.empty()) {
92     tensor_raw->shape_signature = shape_signature;
93   }
94   tensor_raw->type = type;
95 
96   tensor->reset(tensor_raw);
97 }
98 
99 // Create a new TensorT object with quantization parameters.
MakeTensorWithQuantParam(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,float scale,int64_t zero_point,std::unique_ptr<TensorT> * tensor)100 void MakeTensorWithQuantParam(const string& name,
101                               const std::vector<int32_t>& shape,
102                               const std::vector<int32_t>& shape_signature,
103                               const TensorType& type, float scale,
104                               int64_t zero_point,
105                               std::unique_ptr<TensorT>* tensor) {
106   MakeTensor(name, shape, shape_signature, type, tensor);
107   (*tensor)->quantization = std::make_unique<QuantizationParametersT>();
108   (*tensor)->quantization->scale.push_back(scale);
109   (*tensor)->quantization->zero_point.push_back(zero_point);
110 }
111 
QuantizationParametersExist(const TensorT * tensor)112 bool QuantizationParametersExist(const TensorT* tensor) {
113   return tensor->quantization != nullptr &&
114          !tensor->quantization->scale.empty() &&
115          !tensor->quantization->zero_point.empty();
116 }
117 
HasBuffer(const ModelT * model,const SubGraphT * subgraph,int tensor_index)118 bool HasBuffer(const ModelT* model, const SubGraphT* subgraph,
119                int tensor_index) {
120   const int buffer_index = subgraph->tensors[tensor_index]->buffer;
121   BufferT* buffer = model->buffers[buffer_index].get();
122   if (buffer == nullptr || buffer->data.empty()) {
123     return false;
124   }
125   return true;
126 }
127 
HasMinMax(const TensorT * tensor)128 bool HasMinMax(const TensorT* tensor) {
129   return tensor->quantization && !tensor->quantization->min.empty() &&
130          !tensor->quantization->max.empty();
131 }
132 
SetOperatorCodeVersion(ModelT * model)133 void SetOperatorCodeVersion(ModelT* model) {
134   for (int subgraph_idx = 0, end = model->subgraphs.size(); subgraph_idx < end;
135        subgraph_idx++) {
136     SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
137     // Iterate backward to avoid messing with index.
138     for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
139       OperatorT* op = subgraph->operators[op_idx].get();
140       OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
141       operator_property::OperatorProperty property =
142           operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
143       if (property.quantizable && op_code->version < property.version) {
144         // Only update the versions of quantizable operations if the original
145         // version is lesser than minimum quantized one mentioned by
146         // OperatorProperty.
147         op_code->version = property.version;
148       }
149     }
150   }
151 }
152 
WriteFile(const std::string & out_file,const uint8_t * bytes,size_t num_bytes)153 void WriteFile(const std::string& out_file, const uint8_t* bytes,
154                size_t num_bytes) {
155   std::fstream stream(out_file, std::ios::binary | std::ios::out);
156   for (size_t i = 0; i < num_bytes; i++) {
157     stream << bytes[i];
158   }
159   TFLITE_DCHECK(!stream.bad() && !stream.fail());
160 }
161 
FinishModel(const tflite::ModelT * model)162 std::unique_ptr<flatbuffers::FlatBufferBuilder> FinishModel(
163     const tflite::ModelT* model) {
164   std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
165       new flatbuffers::FlatBufferBuilder());
166   auto packed_model = tflite::Model::Pack(*builder, model);
167   tflite::FinishModelBuffer(*builder, packed_model);
168   return builder;
169 }
170 
CreateMutableModelFromFile(const string & model_filepath)171 std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
172     const string& model_filepath) {
173   auto fb_model =
174       tflite::FlatBufferModel::BuildFromFile(model_filepath.c_str());
175   auto tflite_model = fb_model->GetModel();
176   auto copied_model = std::make_unique<tflite::ModelT>();
177   tflite_model->UnPackTo(copied_model.get(), nullptr);
178   return copied_model;
179 }
180 
181 }  // namespace utils
182 }  // namespace optimize
183 }  // namespace tflite
184