/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/model_utils.h"
16
17 #include <fstream>
18 #include <memory>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/model.h"
25 #include "tensorflow/lite/schema/schema_conversion_utils.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 #include "tensorflow/lite/schema/schema_utils.h"
28 #include "tensorflow/lite/tools/optimize/operator_property.h"
29
30 namespace tflite {
31 namespace optimize {
32 namespace utils {
33
34 namespace {
35
36 // Returns the index of the OpCode.
37 // If a OpCode doesn't exist, adds it and returns its index.
GetOrInsertOpCodeIndex(ModelT * model,const BuiltinOperator & op_code,int32_t version)38 int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code,
39 int32_t version) {
40 for (size_t i = 0; i < model->operator_codes.size(); ++i) {
41 if (GetBuiltinCode(model->operator_codes[i].get()) == op_code) {
42 return i;
43 }
44 }
45 model->operator_codes.push_back(std::make_unique<OperatorCodeT>());
46 int op_code_idx = model->operator_codes.size() - 1;
47 model->operator_codes[op_code_idx]->builtin_code = op_code;
48 model->operator_codes[op_code_idx]->deprecated_builtin_code =
49 ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code);
50 // Version 2 and onwards supports INT8 inputs.
51 model->operator_codes[op_code_idx]->version = version;
52
53 // Return the index of the newly placed OperatorCodeT.
54 return op_code_idx;
55 }
56
57 } // namespace
58
59 // Creates a Dequantize OperatorT object.
MakeDequantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)60 void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
61 int32_t input, int32_t output) {
62 OperatorT* op_raw = new OperatorT;
63 // Version 2 and onwards supports INT8 inputs.
64 op_raw->opcode_index =
65 GetOrInsertOpCodeIndex(model, BuiltinOperator_DEQUANTIZE, 2);
66 op_raw->inputs = {input};
67 op_raw->outputs = {output};
68
69 op->reset(op_raw);
70 }
71
72 // Creates a Quantize OperatorT object.
MakeQuantizeOperator(ModelT * model,std::unique_ptr<OperatorT> * op,int32_t input,int32_t output)73 void MakeQuantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
74 int32_t input, int32_t output) {
75 OperatorT* op_raw = new OperatorT;
76 op_raw->opcode_index =
77 GetOrInsertOpCodeIndex(model, BuiltinOperator_QUANTIZE, 1);
78 op_raw->inputs = {input};
79 op_raw->outputs = {output};
80
81 op->reset(op_raw);
82 }
83
84 // Create a new TensorT object without quantization parameters.
MakeTensor(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,std::unique_ptr<TensorT> * tensor)85 void MakeTensor(const string& name, const std::vector<int32_t>& shape,
86 const std::vector<int32_t>& shape_signature,
87 const TensorType& type, std::unique_ptr<TensorT>* tensor) {
88 TensorT* tensor_raw = new TensorT;
89 tensor_raw->name = name;
90 tensor_raw->shape = shape;
91 if (!shape_signature.empty()) {
92 tensor_raw->shape_signature = shape_signature;
93 }
94 tensor_raw->type = type;
95
96 tensor->reset(tensor_raw);
97 }
98
99 // Create a new TensorT object with quantization parameters.
MakeTensorWithQuantParam(const string & name,const std::vector<int32_t> & shape,const std::vector<int32_t> & shape_signature,const TensorType & type,float scale,int64_t zero_point,std::unique_ptr<TensorT> * tensor)100 void MakeTensorWithQuantParam(const string& name,
101 const std::vector<int32_t>& shape,
102 const std::vector<int32_t>& shape_signature,
103 const TensorType& type, float scale,
104 int64_t zero_point,
105 std::unique_ptr<TensorT>* tensor) {
106 MakeTensor(name, shape, shape_signature, type, tensor);
107 (*tensor)->quantization = std::make_unique<QuantizationParametersT>();
108 (*tensor)->quantization->scale.push_back(scale);
109 (*tensor)->quantization->zero_point.push_back(zero_point);
110 }
111
QuantizationParametersExist(const TensorT * tensor)112 bool QuantizationParametersExist(const TensorT* tensor) {
113 return tensor->quantization != nullptr &&
114 !tensor->quantization->scale.empty() &&
115 !tensor->quantization->zero_point.empty();
116 }
117
HasBuffer(const ModelT * model,const SubGraphT * subgraph,int tensor_index)118 bool HasBuffer(const ModelT* model, const SubGraphT* subgraph,
119 int tensor_index) {
120 const int buffer_index = subgraph->tensors[tensor_index]->buffer;
121 BufferT* buffer = model->buffers[buffer_index].get();
122 if (buffer == nullptr || buffer->data.empty()) {
123 return false;
124 }
125 return true;
126 }
127
HasMinMax(const TensorT * tensor)128 bool HasMinMax(const TensorT* tensor) {
129 return tensor->quantization && !tensor->quantization->min.empty() &&
130 !tensor->quantization->max.empty();
131 }
132
SetOperatorCodeVersion(ModelT * model)133 void SetOperatorCodeVersion(ModelT* model) {
134 for (int subgraph_idx = 0, end = model->subgraphs.size(); subgraph_idx < end;
135 subgraph_idx++) {
136 SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
137 // Iterate backward to avoid messing with index.
138 for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
139 OperatorT* op = subgraph->operators[op_idx].get();
140 OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
141 operator_property::OperatorProperty property =
142 operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
143 if (property.quantizable && op_code->version < property.version) {
144 // Only update the versions of quantizable operations if the original
145 // version is lesser than minimum quantized one mentioned by
146 // OperatorProperty.
147 op_code->version = property.version;
148 }
149 }
150 }
151 }
152
WriteFile(const std::string & out_file,const uint8_t * bytes,size_t num_bytes)153 void WriteFile(const std::string& out_file, const uint8_t* bytes,
154 size_t num_bytes) {
155 std::fstream stream(out_file, std::ios::binary | std::ios::out);
156 for (size_t i = 0; i < num_bytes; i++) {
157 stream << bytes[i];
158 }
159 TFLITE_DCHECK(!stream.bad() && !stream.fail());
160 }
161
FinishModel(const tflite::ModelT * model)162 std::unique_ptr<flatbuffers::FlatBufferBuilder> FinishModel(
163 const tflite::ModelT* model) {
164 std::unique_ptr<flatbuffers::FlatBufferBuilder> builder(
165 new flatbuffers::FlatBufferBuilder());
166 auto packed_model = tflite::Model::Pack(*builder, model);
167 tflite::FinishModelBuffer(*builder, packed_model);
168 return builder;
169 }
170
CreateMutableModelFromFile(const string & model_filepath)171 std::unique_ptr<tflite::ModelT> CreateMutableModelFromFile(
172 const string& model_filepath) {
173 auto fb_model =
174 tflite::FlatBufferModel::BuildFromFile(model_filepath.c_str());
175 auto tflite_model = fb_model->GetModel();
176 auto copied_model = std::make_unique<tflite::ModelT>();
177 tflite_model->UnPackTo(copied_model.get(), nullptr);
178 return copied_model;
179 }
180
181 } // namespace utils
182 } // namespace optimize
183 } // namespace tflite
184