/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"

#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
#include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_convert.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"

namespace toco {
using mlir::lite::StringSet;

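// Dumps pre- and post-conversion logs (plus graphviz output) under
// toco_flags->conversion_summary_dir(). Both logs are built by round-tripping
// through `toco::Model`; see the TODO inside about removing that dependency.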
void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
                                 toco::TocoFlags* toco_flags,
                                 const std::string& input_contents_txt,
                                 const std::string& output_file_contents_txt,
                                 const std::string& error_message,
                                 GraphVizDumpOptions* dump_options) {
  // Make sure the graphviz file will be dumped under the same folder.
  dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
  // Here we construct the `toco::Model` class based on the input graph def;
  // it will then be used to populate the conversion log.
  // TODO(haoliang): Don't depend on `toco::Model`.
  std::unique_ptr<toco::Model> imported_model =
      toco::Import(*toco_flags, model_flags, input_contents_txt);
  // Dump pre-conversion toco logs.
  TocoConversionLog toco_log_before;
  PopulateConversionLog(*imported_model, &toco_log_before);
  std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
                                "/toco_log_before.pb");
  toco_log_before.SerializeToOstream(&osstream_before);
  osstream_before.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);

  // Populate the post-conversion log; for convenience, initialize the
  // `toco::Model` class from the generated flatbuffer.
  toco_flags->set_input_format(toco::FileFormat::TFLITE);
  std::unique_ptr<toco::Model> flatbuffer_model =
      toco::Import(*toco_flags, model_flags, output_file_contents_txt);
  // Dump post-conversion toco logs.
  TocoConversionLog toco_log_after;
  PopulateConversionLog(*flatbuffer_model, &toco_log_after);
  // Make sure we sanitize the error message.
  toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
  std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
                              "/toco_log_after.pb");
  toco_log_after.SerializeToOstream(&ostream_after);
  ostream_after.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
}

// NOTE(aselle): We are using raw PyObjects here because we want to make
// sure we input and output bytes rather than unicode strings for Python3.
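//
// Illustrative Python-side usage only (a sketch; the pybind wrapper module
// name `_pywrap_toco_api` is an assumption and may differ in practice):
//
//   tflite_model = _pywrap_toco_api.TocoConvert(
//       model_flags.SerializeToString(),  # serialized toco::ModelFlags
//       toco_flags.SerializeToString(),   # serialized toco::TocoFlags
//       graph_def.SerializeToString(),    # serialized tf.GraphDef bytes
//       False,                            # extended_return
//       debug_info.SerializeToString(),   # optional GraphDebugInfo
//       True)                             # enable_mlir_converter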
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                      PyObject* toco_flags_proto_txt_raw,
                      PyObject* input_contents_txt_raw, bool extended_return,
                      PyObject* debug_info_txt_raw,
                      bool enable_mlir_converter) {
  // Use the Python C API to validate and convert arguments: bytes in py3,
  // str in py2.
  auto ConvertArg = [&](PyObject* obj, bool* error) {
    char* buf;
    Py_ssize_t len;
    if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
      *error = true;
      return std::string();
    } else {
      *error = false;
      return std::string(buf, len);
    }
  };

  bool error;
  std::string model_flags_proto_txt =
      ConvertArg(model_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
    return nullptr;
  }
  std::string toco_flags_proto_txt =
      ConvertArg(toco_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
    return nullptr;
  }

  // Use TOCO to produce new outputs.
  toco::ModelFlags model_flags;
  if (!model_flags.ParseFromString(model_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert Model to Python String.");
    return nullptr;
  }
  toco::TocoFlags toco_flags;
  if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert Toco to Python String.");
    return nullptr;
  }

  tensorflow::GraphDebugInfo debug_info;
  if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
    std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
    if (error) {
      PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
      return nullptr;
    }
    if (!debug_info.ParseFromString(debug_info_txt)) {
      PyErr_SetString(PyExc_ValueError,
                      "Failed to convert DebugInfo to Python String.");
      return nullptr;
    }
  }

  tensorflow::GraphDef graph_def;
  std::string input_contents_txt;
  if (model_flags.saved_model_dir().empty()) {
    input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
    if (error) {
      PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
      return nullptr;
    }
    if (!model_flags.use_hlo_import() &&
        !graph_def.ParseFromString(input_contents_txt)) {
      PyErr_SetString(PyExc_ValueError,
                      "Failed to convert GraphDef to Python String.");
      return nullptr;
    }
  }

  auto& dump_options = *GraphVizDumpOptions::singleton();
  if (toco_flags.has_dump_graphviz_dir()) {
    dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
  }
  if (toco_flags.has_dump_graphviz_include_video()) {
    dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
  }

  std::string output_file_contents_txt;
  tensorflow::Status status;
  int64_t arithmetic_ops_count;

  // Convert model.
  if (enable_mlir_converter) {
    if (model_flags.use_hlo_import() && model_flags.has_saved_model_dir()) {
      PyErr_SetString(PyExc_ValueError,
                      "Cannot specify both saved_model and hlo import.");
      return nullptr;
    }

    if (model_flags.use_hlo_import()) {
      status = tensorflow::ConvertJaxToTFLiteFlatBuffer(
          input_contents_txt, model_flags, toco_flags,
          &output_file_contents_txt);
    } else if (!model_flags.saved_model_dir().empty()) {
      status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
          model_flags, toco_flags, &output_file_contents_txt);
    } else {
      tensorflow::GraphDef graph_def;
      if (!graph_def.ParseFromString(input_contents_txt)) {
        PyErr_SetString(PyExc_ValueError,
                        "Failed to convert GraphDef to Python String.");
        return nullptr;
      }

      status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
          model_flags, toco_flags, debug_info, graph_def,
          &output_file_contents_txt);
      if (!toco_flags.conversion_summary_dir().empty()) {
        PopulateConversionLogHelper(
            model_flags, &toco_flags, input_contents_txt,
            output_file_contents_txt, status.error_message(), &dump_options);
      }
    }
  } else {
    status = Convert(input_contents_txt, toco_flags, model_flags,
                     &output_file_contents_txt, &arithmetic_ops_count);
  }

  if (!status.ok()) {
    PyErr_SetString(PyExc_Exception, status.error_message().c_str());
    return nullptr;
  }
  if (extended_return && !enable_mlir_converter) {
    PyObject* dict = PyDict_New();
    PyDict_SetItemString(
        dict, "flatbuffer",
        ::tflite::python_utils::ConvertToPyString(
            output_file_contents_txt.data(), output_file_contents_txt.size()));
    PyDict_SetItemString(dict, "arithmetic_ops",
                         PyLong_FromLong(arithmetic_ops_count));
    return dict;
  }
  // Convert arguments back to bytes (py3) or str (py2).
  return ::tflite::python_utils::ConvertToPyString(
      output_file_contents_txt.data(), output_file_contents_txt.size());
}

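// Maps a toco::IODataType enum value onto the closest tflite::TensorType.
// Any value without an explicit mapping falls back to FLOAT32.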
tflite::TensorType FromTocoDataTypeToTflitToTensorType(int inference_type) {
  switch (inference_type) {
    case toco::IODataType::QUANTIZED_INT16:
      return tflite::TensorType_INT16;
    case toco::IODataType::QUANTIZED_UINT8:
      return tflite::TensorType_UINT8;
    case toco::IODataType::UINT8:
      return tflite::TensorType_UINT8;
    case toco::IODataType::QUANTIZED_INT8:
      return tflite::TensorType_INT8;
    case toco::IODataType::INT8:
      return tflite::TensorType_INT8;
    default:
      return tflite::TensorType_FLOAT32;
  }
}

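// Copies a Python list or set of strings into `string_set`. A null
// `py_denylist` is treated as empty. Returns 0 on success, or -1 if an
// element cannot be converted to a byte string.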
int ToStringSet(PyObject* py_denylist, StringSet* string_set) {
  using tflite::python_utils::ConvertFromPyString;
  // Ensure op_denylist is non-null.
  if (!py_denylist) {
    return 0;
  }
  if (PyList_Check(py_denylist)) {
    for (int i = 0; i < PyList_GET_SIZE(py_denylist); ++i) {
      PyObject* value = PyList_GetItem(py_denylist, i);
      char* str_buf;
      Py_ssize_t length;
      if (ConvertFromPyString(value, &str_buf, &length) == -1) {
        return -1;
      }
      string_set->emplace(str_buf, length);
    }
  }
  if (PySet_Check(py_denylist)) {
    // PySet_New and PySet_Pop both return new references; release them once
    // the contents have been copied, including on the error path.
    auto* tmp = PySet_New(py_denylist);
    while (PySet_GET_SIZE(tmp)) {
      PyObject* value = PySet_Pop(tmp);
      char* str_buf;
      Py_ssize_t length;
      if (ConvertFromPyString(value, &str_buf, &length) == -1) {
        Py_DECREF(value);
        Py_DECREF(tmp);
        return -1;
      }
      string_set->emplace(str_buf, length);
      Py_DECREF(value);
    }
    Py_DECREF(tmp);
  }
  return 0;
}

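// Quantizes the serialized TFLite flatbuffer in `data` with the MLIR
// quantizer and returns the result as Python bytes, or nullptr with a Python
// exception set on failure. `op_denylist` and `node_denylist` name ops and
// nodes that should be left unquantized.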
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
                            bool fully_quantize, int inference_type,
                            int input_data_type, int output_data_type,
                            bool enable_numeric_verify,
                            bool enable_whole_model_verify,
                            PyObject* op_denylist, PyObject* node_denylist) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }

  StringSet denylisted_ops;
  StringSet denylisted_nodes;
  if (ToStringSet(op_denylist, &denylisted_ops) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert op denylist PyObject");
    return nullptr;
  }
  if (ToStringSet(node_denylist, &denylisted_nodes) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert node denylist PyObject");
    return nullptr;
  }

  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = std::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  tflite::TensorType inference_tensor_type =
      FromTocoDataTypeToTflitToTensorType(inference_type);
  tflite::TensorType input_type =
      FromTocoDataTypeToTflitToTensorType(input_data_type);
  tflite::TensorType output_type =
      FromTocoDataTypeToTflitToTensorType(output_data_type);

  flatbuffers::FlatBufferBuilder builder;
  auto status = mlir::lite::QuantizeModel(
      *tflite_model, input_type, output_type, inference_tensor_type, {},
      disable_per_channel, fully_quantize, &builder, error_reporter.get(),
      enable_numeric_verify, enable_whole_model_verify,
      /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes);

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

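// Sparsifies the serialized TFLite flatbuffer in `data` with the MLIR
// sparsifier and returns the result as Python bytes, or nullptr with a
// Python exception set on failure.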
PyObject* MlirSparsifyModel(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = std::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  flatbuffers::FlatBufferBuilder builder;
  auto status =
      mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

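// Registers each text-format tensorflow::OpDef in `list` with the global op
// registry, plus a no-op CPU kernel so graphs using the op can be imported.
// A minimal text-format entry might look like this (an illustrative sketch,
// not taken from any real op):
//
//   name: "MyCustomOp"
//   input_arg { name: "x" type: DT_FLOAT }
//   output_arg { name: "y" type: DT_FLOAT }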
PyObject* RegisterCustomOpdefs(PyObject* list) {
  if (!PyList_Check(list)) {
    PyErr_SetString(PyExc_TypeError, "Expected list in argument");
    return nullptr;
  }

  int64_t size = PyList_Size(list);
  for (int i = 0; i < size; ++i) {
    // Get character array from Python object.
    char* tf_opdefs;
    Py_ssize_t len;
    if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
                                                  &tf_opdefs, &len) == -1) {
      PyErr_Format(PyExc_ValueError,
                   "Failed to convert Python string at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }

    // Parse op def from character array.
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
      PyErr_Format(
          PyExc_ValueError,
          "Failed to parse opdefs at index %d of custom op defs argument: %s",
          i, tf_opdefs);
      return nullptr;
    }

    // Register extra opdefs with the TensorFlow global op registry.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data)
            -> tensorflow::Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return ::tensorflow::OkStatus();
        });

    // Register the corresponding fake op kernel.
    const char* node_name = opdef.name().c_str();
    const char* op_name = opdef.name().c_str();
    const char* device_name = "CPU";
    static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    };

    TF_KernelBuilder* builder =
        TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
                            fake_compute_func, /*delete_func=*/nullptr);

    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    if (TF_GetCode(status) != TF_OK) {
      TF_DeleteStatus(status);
      PyErr_Format(PyExc_ValueError,
                   "Failed to register fake op kernel at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }
    TF_DeleteStatus(status);
  }

  Py_RETURN_TRUE;
}

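// Returns the errors accumulated by the MLIR ErrorCollector as serialized
// protos, then clears the collector so the next conversion starts clean.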
const std::vector<std::string> RetrieveCollectedErrors() {
  mlir::TFL::ErrorCollector* collector =
      mlir::TFL::ErrorCollector::GetErrorCollector();
  std::vector<std::string> collected_errors;
  for (const auto& error_data : collector->CollectedErrors()) {
    collected_errors.push_back(error_data.SerializeAsString());
  }
  collector->Clear();
  return collected_errors;
}

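// Thin wrapper forwarding to ::tensorflow::FlatBufferFileToMlir. `model` is
// either raw flatbuffer contents or a path to a flatbuffer file, depending
// on `input_is_filepath`.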
std::string FlatBufferFileToMlir(const std::string& model,
                                 bool input_is_filepath) {
  return ::tensorflow::FlatBufferFileToMlir(model, input_is_filepath);
}

}  // namespace toco