/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"

#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
#include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_convert.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"

namespace toco {
using mlir::lite::StringSet;

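// Dumps pre- and post-conversion `TocoConversionLog` protos (plus graphviz
// dumps) under `toco_flags->conversion_summary_dir()` for conversion
// debugging.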
void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
                                 toco::TocoFlags* toco_flags,
                                 const std::string& input_contents_txt,
                                 const std::string& output_file_contents_txt,
                                 const std::string& error_message,
                                 GraphVizDumpOptions* dump_options) {
  // Make sure the graphviz file will be dumped under the same folder.
  dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
  // Here we construct the `toco::Model` class based on the input graph def;
  // it will then be used to populate the conversion log.
  // TODO(haoliang): Don't depend on `toco::Model`.
  std::unique_ptr<toco::Model> imported_model =
      toco::Import(*toco_flags, model_flags, input_contents_txt);
  // Dump pre-conversion toco logs.
  TocoConversionLog toco_log_before;
  PopulateConversionLog(*imported_model, &toco_log_before);
  std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
                                "/toco_log_before.pb");
  toco_log_before.SerializeToOstream(&osstream_before);
  osstream_before.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);

  // Populate the post-conversion log; for convenience, initialize the
  // `toco::Model` class from the generated flatbuffer.
  toco_flags->set_input_format(toco::FileFormat::TFLITE);
  std::unique_ptr<toco::Model> flatbuffer_model =
      toco::Import(*toco_flags, model_flags, output_file_contents_txt);
  // Dump post-conversion toco logs.
  TocoConversionLog toco_log_after;
  PopulateConversionLog(*flatbuffer_model, &toco_log_after);
  // Make sure we sanitize the error message.
  toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
  std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
                              "/toco_log_after.pb");
  toco_log_after.SerializeToOstream(&ostream_after);
  ostream_after.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
}

// NOTE(aselle): We are using raw PyObjects here because we want to make
// sure we input and output bytes rather than unicode strings for Python3.
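// Converts a model to a TFLite flatbuffer. The input is selected via
// `model_flags`: a SavedModel directory, an HLO module (Jax import), or a
// serialized GraphDef. On success the flatbuffer is returned as Python
// bytes; on failure a Python exception is set and nullptr is returned. With
// `extended_return` set and the legacy (non-MLIR) converter, a dict holding
// the flatbuffer and an arithmetic op count is returned instead.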
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                      PyObject* toco_flags_proto_txt_raw,
                      PyObject* input_contents_txt_raw, bool extended_return,
                      PyObject* debug_info_txt_raw,
                      bool enable_mlir_converter) {
  // Use the Python C API to validate and convert arguments: bytes in py3,
  // str in py2.
  auto ConvertArg = [&](PyObject* obj, bool* error) {
    char* buf;
    Py_ssize_t len;
    if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
      *error = true;
      return std::string();
    } else {
      *error = false;
      return std::string(buf, len);
    }
  };

  bool error;
  std::string model_flags_proto_txt =
      ConvertArg(model_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
    return nullptr;
  }
  std::string toco_flags_proto_txt =
      ConvertArg(toco_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
    return nullptr;
  }

  // Use TOCO to produce new outputs.
  toco::ModelFlags model_flags;
  if (!model_flags.ParseFromString(model_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to parse ModelFlags proto from string.");
    return nullptr;
  }
  toco::TocoFlags toco_flags;
  if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to parse TocoFlags proto from string.");
    return nullptr;
  }

  tensorflow::GraphDebugInfo debug_info;
  if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
    std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
    if (error) {
      PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
      return nullptr;
    }
    if (!debug_info.ParseFromString(debug_info_txt)) {
      PyErr_SetString(PyExc_ValueError,
                      "Failed to parse DebugInfo proto from string.");
      return nullptr;
    }
  }

  tensorflow::GraphDef graph_def;
  std::string input_contents_txt;
  if (model_flags.saved_model_dir().empty()) {
    input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
    if (error) {
      PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
      return nullptr;
    }
    if (!model_flags.use_hlo_import() &&
        !graph_def.ParseFromString(input_contents_txt)) {
      PyErr_SetString(PyExc_ValueError,
                      "Failed to parse GraphDef proto from string.");
      return nullptr;
    }
  }

  auto& dump_options = *GraphVizDumpOptions::singleton();
  if (toco_flags.has_dump_graphviz_dir()) {
    dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
  }
  if (toco_flags.has_dump_graphviz_include_video()) {
    dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
  }

  std::string output_file_contents_txt;
  tensorflow::Status status;
  int64_t arithmetic_ops_count;

  // Convert model.
  if (enable_mlir_converter) {
    if (model_flags.use_hlo_import() && model_flags.has_saved_model_dir()) {
      PyErr_SetString(PyExc_ValueError,
                      "Cannot specify both saved_model and hlo import.");
      return nullptr;
    }

    if (model_flags.use_hlo_import()) {
      status = tensorflow::ConvertJaxToTFLiteFlatBuffer(
          input_contents_txt, model_flags, toco_flags,
          &output_file_contents_txt);
    } else if (!model_flags.saved_model_dir().empty()) {
      status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
          model_flags, toco_flags, &output_file_contents_txt);
    } else {
      tensorflow::GraphDef graph_def;
      if (!graph_def.ParseFromString(input_contents_txt)) {
        PyErr_SetString(PyExc_ValueError,
                        "Failed to parse GraphDef proto from string.");
        return nullptr;
      }

      status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
          model_flags, toco_flags, debug_info, graph_def,
          &output_file_contents_txt);
      if (!toco_flags.conversion_summary_dir().empty()) {
        PopulateConversionLogHelper(
            model_flags, &toco_flags, input_contents_txt,
            output_file_contents_txt, status.error_message(), &dump_options);
      }
    }
  } else {
    status = Convert(input_contents_txt, toco_flags, model_flags,
                     &output_file_contents_txt, &arithmetic_ops_count);
  }

  if (!status.ok()) {
    PyErr_SetString(PyExc_Exception, status.error_message().c_str());
    return nullptr;
  }
  if (extended_return && !enable_mlir_converter) {
    PyObject* dict = PyDict_New();
    PyDict_SetItemString(
        dict, "flatbuffer",
        ::tflite::python_utils::ConvertToPyString(
            output_file_contents_txt.data(), output_file_contents_txt.size()));
    PyDict_SetItemString(dict, "arithmetic_ops",
                         PyLong_FromLong(arithmetic_ops_count));
    return dict;
  }
  // Convert the result back to bytes (py3) or str (py2).
  return ::tflite::python_utils::ConvertToPyString(
      output_file_contents_txt.data(), output_file_contents_txt.size());
}

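// Maps a `toco::IODataType` value onto the corresponding
// `tflite::TensorType`; unrecognized values fall back to FLOAT32.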
tflite::TensorType FromTocoDataTypeToTflitToTensorType(int inference_type) {
  switch (inference_type) {
    case toco::IODataType::QUANTIZED_INT16:
      return tflite::TensorType_INT16;
    case toco::IODataType::QUANTIZED_UINT8:
      return tflite::TensorType_UINT8;
    case toco::IODataType::UINT8:
      return tflite::TensorType_UINT8;
    case toco::IODataType::QUANTIZED_INT8:
      return tflite::TensorType_INT8;
    case toco::IODataType::INT8:
      return tflite::TensorType_INT8;
    default:
      return tflite::TensorType_FLOAT32;
  }
}

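// Copies the strings in `py_denylist` (a Python list or set of str/bytes)
// into `string_set`. A null input is treated as empty. Returns 0 on success
// and -1 if an element cannot be converted.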
int ToStringSet(PyObject* py_denylist, StringSet* string_set) {
  using tflite::python_utils::ConvertFromPyString;
  // A null denylist is treated as empty.
  if (!py_denylist) {
    return 0;
  }
  if (PyList_Check(py_denylist)) {
    for (int i = 0; i < PyList_GET_SIZE(py_denylist); ++i) {
      PyObject* value = PyList_GetItem(py_denylist, i);  // Borrowed reference.
      char* str_buf;
      Py_ssize_t length;
      if (ConvertFromPyString(value, &str_buf, &length) == -1) {
        return -1;
      }
      string_set->emplace(str_buf, length);
    }
  }
  if (PySet_Check(py_denylist)) {
    auto* tmp = PySet_New(py_denylist);
    while (PySet_GET_SIZE(tmp)) {
      PyObject* value = PySet_Pop(tmp);  // New reference.
      char* str_buf;
      Py_ssize_t length;
      if (ConvertFromPyString(value, &str_buf, &length) == -1) {
        Py_DECREF(value);
        Py_DECREF(tmp);
        return -1;
      }
      string_set->emplace(str_buf, length);
      // Release the references created above to avoid leaking them.
      Py_DECREF(value);
    }
    Py_DECREF(tmp);
  }
  return 0;
}

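// Quantizes the flatbuffer model held in `data` with the MLIR quantizer and
// returns the result as Python bytes, or nullptr with a Python exception set
// on failure. `inference_type`, `input_data_type`, and `output_data_type`
// are `toco::IODataType` values. As a sketch, a hypothetical Python-side
// caller going through a pybind11 wrapper (the exported name below is an
// assumption, not part of this file) might look like:
//
//   quantized = wrapped_experimental_mlir_quantize(
//       model_bytes, False, True, inference_type, input_type, output_type,
//       False, False, op_denylist, node_denylist)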
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
                            bool fully_quantize, int inference_type,
                            int input_data_type, int output_data_type,
                            bool enable_numeric_verify,
                            bool enable_whole_model_verify,
                            PyObject* op_denylist, PyObject* node_denylist) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }

  StringSet denylisted_ops;
  StringSet denylisted_nodes;
  if (ToStringSet(op_denylist, &denylisted_ops) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert op denylist PyObject");
    return nullptr;
  }
  if (ToStringSet(node_denylist, &denylisted_nodes) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert node denylist PyObject");
    return nullptr;
  }

  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = std::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  tflite::TensorType inference_tensor_type =
      FromTocoDataTypeToTflitToTensorType(inference_type);
  tflite::TensorType input_type =
      FromTocoDataTypeToTflitToTensorType(input_data_type);
  tflite::TensorType output_type =
      FromTocoDataTypeToTflitToTensorType(output_data_type);

  flatbuffers::FlatBufferBuilder builder;
  auto status = mlir::lite::QuantizeModel(
      *tflite_model, input_type, output_type, inference_tensor_type, {},
      disable_per_channel, fully_quantize, &builder, error_reporter.get(),
      enable_numeric_verify, enable_whole_model_verify,
      /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes);

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

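// Runs the MLIR sparsifier over the flatbuffer model held in `data` and
// returns the sparsified flatbuffer as Python bytes, or nullptr with a
// Python exception set on failure.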
PyObject* MlirSparsifyModel(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = std::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  flatbuffers::FlatBufferBuilder builder;
  auto status =
      mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

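// Parses each element of `list` as a text-format `tensorflow::OpDef`,
// registers it with the global op registry, and installs a no-op CPU kernel
// for it so graphs containing the custom op can be imported. Returns
// Py_True on success, or nullptr with a Python exception set on failure.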
PyObject* RegisterCustomOpdefs(PyObject* list) {
  if (!PyList_Check(list)) {
    PyErr_SetString(PyExc_TypeError, "Expected list in argument");
    return nullptr;
  }

  int64_t size = PyList_Size(list);
  for (int i = 0; i < size; ++i) {
    // Get character array from Python object.
    char* tf_opdefs;
    Py_ssize_t len;
    if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
                                                  &tf_opdefs, &len) == -1) {
      PyErr_Format(PyExc_ValueError,
                   "Failed to convert Python string at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }

    // Parse op def from character array.
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
      PyErr_Format(
          PyExc_ValueError,
          "Failed to parse opdefs at index %d of custom op defs argument: %s",
          i, tf_opdefs);
      return nullptr;
    }

    // Register extra opdefs to TensorFlow global op registry.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](
            tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return ::tensorflow::OkStatus();
        });

    // Register the corresponding fake op kernel.
    const char* node_name = opdef.name().c_str();
    const char* op_name = opdef.name().c_str();
    const char* device_name = "CPU";
    static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    };

    TF_KernelBuilder* builder =
        TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
                            fake_compute_func, /*delete_func=*/nullptr);

    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    if (TF_GetCode(status) != TF_OK) {
      TF_DeleteStatus(status);
      PyErr_Format(PyExc_ValueError,
                   "Failed to register fake op kernel at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }
    TF_DeleteStatus(status);
  }

  Py_RETURN_TRUE;
}

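// Returns the errors accumulated by the MLIR converter's ErrorCollector as
// serialized protos, then clears the collector.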
const std::vector<std::string> RetrieveCollectedErrors() {
  mlir::TFL::ErrorCollector* collector =
      mlir::TFL::ErrorCollector::GetErrorCollector();
  std::vector<std::string> collected_errors;
  for (const auto& error_data : collector->CollectedErrors()) {
    collected_errors.push_back(error_data.SerializeAsString());
  }
  collector->Clear();
  return collected_errors;
}

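// Translates a TFLite flatbuffer, given either inline or (when
// `input_is_filepath` is true) as a path to a flatbuffer file, into its MLIR
// text representation.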
std::string FlatBufferFileToMlir(const std::string& model,
                                 bool input_is_filepath) {
  return ::tensorflow::FlatBufferFileToMlir(model, input_is_filepath);
}

}  // namespace toco