/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/python/mlir.h"

#include <string>
#include <type_traits>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/InitAllPasses.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/register_passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

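// Backing implementation for the Python bindings declared in
// tensorflow/compiler/mlir/python/mlir.h: importing GraphDefs, FunctionDefs
// and SavedModels into TF MLIR modules and running textual pass pipelines on
// them. Every entry point returns the module serialized to a string, or the
// sentinel "// error" with `status` populated on failure.
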
namespace tensorflow {

namespace {
// All the passes we will make available to Python by default.
// TODO(tf): this should be sharded instead of being monolithic like this.
static void RegisterPasses() {
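  // A function-local static initialized by a lambda ensures the registration
  // below runs exactly once, even if RegisterPasses() is called concurrently
  // (C++11 thread-safe static initialization).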
  static bool unique_registration = [] {
    mlir::registerAllPasses();
    mlir::registerTensorFlowPasses();
    mlir::TFDevice::registerTensorFlowDevicePasses();
    mlir::mhlo::registerAllMhloPasses();
    mlir::lmhlo::registerAllLmhloPasses();
    // These are in compiler/mlir/xla and not part of the above MHLO passes.
    mlir::mhlo::registerXlaPasses();
    mlir::mhlo::registerTfXlaPasses();
    mlir::mhlo::registerLegalizeTFPass();
    mlir::mhlo::registerLegalizeTFControlFlowPass();
    mlir::mhlo::registerLegalizeTfTypesPassPass();
    mlir::tosa::registerLegalizeTosaPasses();
    mlir::tosa::registerTFtoTOSALegalizationPipeline();
    mlir::tosa::registerTFLtoTOSALegalizationPipeline();
    mlir::tosa::registerTFTFLtoTOSALegalizationPipeline();
    mlir::tf_saved_model::registerTensorFlowSavedModelPasses();
    return true;
  }();
  (void)unique_registration;
}

// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
// empty.
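// Returns the serialized module on success, or "// error" with `status` set
// on failure.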
std::string RunPassPipelineOnModule(mlir::ModuleOp module,
                                    const std::string& pass_pipeline,
                                    bool show_debug_info, TF_Status* status) {
  RegisterPasses();
  if (!pass_pipeline.empty()) {
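    // `pass_pipeline` uses MLIR's textual pipeline syntax, e.g.
    // "func.func(tf-shape-inference)" (the pass name here is illustrative).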
    mlir::PassManager pm(module.getContext());
    std::string error;
    llvm::raw_string_ostream error_stream(error);
    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
      return "// error";
    }

    mlir::StatusScopedDiagnosticHandler status_handler(module.getContext());
    if (failed(pm.run(module))) {
      Set_TF_Status_from_Status(status, status_handler.ConsumeStatus());
      return "// error";
    }
  }
  return MlirModuleToString(module, show_debug_info);
}

}  // anonymous namespace

static std::string ImportGraphDefImpl(const std::string& proto,
                                      const std::string& pass_pipeline,
                                      bool show_debug_info,
                                      GraphDebugInfo& debug_info,
                                      GraphImportConfig& specs,
                                      TF_Status* status) {
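  // Deserialize the GraphDef proto from the Python-provided buffer.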
  GraphDef graphdef;
  auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }
  mlir::MLIRContext context;
  auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
  if (!module.ok()) {
    Set_TF_Status_from_Status(status, module.status());
    return "// error";
  }

  return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
                                 status);
}

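// Converts the body of the function named in `functiondef_proto` (looked up
// in the eager context's function library) to MLIR, then runs `pass_pipeline`
// on the resulting module.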
std::string ImportFunction(const std::string& functiondef_proto,
                           const std::string& pass_pipeline,
                           bool show_debug_info, TFE_Context* tfe_context,
                           TF_Status* status) {
  FunctionDef functiondef;
  auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

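  // The proto above is only used to recover the function name; the definition
  // itself is resolved through the eager context's function library.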
  const std::string& function_name = functiondef.signature().name();
  EagerContext* cpp_context = ContextFromInterface(unwrap(tfe_context));
  FunctionLibraryDefinition& flib_def = *cpp_context->FuncLibDef();
  const tensorflow::FunctionDef* fdef = flib_def.Find(function_name);
  if (fdef == nullptr) {
    s = tensorflow::errors::NotFound("Cannot find function ", function_name);
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

  std::unique_ptr<tensorflow::FunctionBody> fbody;
  s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
                              &fbody);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }

  mlir::MLIRContext context;
  auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
  if (!module.ok()) {
    Set_TF_Status_from_Status(status, module.status());
    return "// error";
  }

  return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
                                 status);
}

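// Imports a GraphDef with default import settings.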
std::string ImportGraphDef(const std::string& proto,
                           const std::string& pass_pipeline,
                           bool show_debug_info, TF_Status* status) {
  GraphDebugInfo debug_info;
  GraphImportConfig specs;
  return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info,
                            specs, status);
}

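// Overload that also threads the input/output array info (comma-separated
// names, dtypes, and shapes) into the import config.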
std::string ImportGraphDef(const std::string& proto,
                           const std::string& pass_pipeline,
                           bool show_debug_info, absl::string_view input_names,
                           absl::string_view input_data_types,
                           absl::string_view input_data_shapes,
                           absl::string_view output_names, TF_Status* status) {
  GraphDebugInfo debug_info;
  GraphImportConfig specs;
  auto s = ParseInputArrayInfo(input_names, input_data_types, input_data_shapes,
                               &specs.inputs);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return "// error";
  }
  if (!output_names.empty()) {
    specs.outputs = absl::StrSplit(output_names, ',');
  }
  return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info,
                            specs, status);
}

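// Converts a SavedModel (V2, object-graph based) on disk to its MLIR textual
// representation; `exported_names_str` is a comma-separated list of names to
// export.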
std::string ExperimentalConvertSavedModelToMlir(
    const std::string& saved_model_path, const std::string& exported_names_str,
    bool show_debug_info, TF_Status* status) {
  // Load the saved model into a SavedModelV2Bundle.
  tensorflow::SavedModelV2Bundle bundle;
  auto load_status =
      tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
  if (!load_status.ok()) {
    Set_TF_Status_from_Status(status, load_status);
    return "// error";
  }

  // Convert the SavedModelV2Bundle to an MLIR module.
  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;
  auto module_or = ConvertSavedModelToMlir(
      &bundle, &context, absl::Span<std::string>(exported_names));
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  return MlirModuleToString(*std::move(module_or).value(), show_debug_info);
}

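// "Lite" V1 converter: imports the SavedModel's SignatureDefs to MLIR without
// restoring variable values through a Session.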
std::string ExperimentalConvertSavedModelV1ToMlirLite(
    const std::string& saved_model_path, const std::string& exported_names_str,
    const std::string& tags, bool upgrade_legacy, bool show_debug_info,
    TF_Status* status) {
  std::unordered_set<string> tag_set =
      absl::StrSplit(tags, ',', absl::SkipEmpty());

  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;

  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or = SavedModelSignatureDefsToMlirImportLite(
      saved_model_path, tag_set, absl::Span<std::string>(exported_names),
      &context, import_options);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  return MlirModuleToString(*module_or.value(), show_debug_info);
}

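// Full V1 converter: loads the SavedModel with a Session, converts it to
// MLIR, optionally lifts variables (`lift_variables`), and then runs the TF
// standard pipeline over the module.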
std::string ExperimentalConvertSavedModelV1ToMlir(
    const std::string& saved_model_path, const std::string& exported_names_str,
    const std::string& tags, bool lift_variables, bool upgrade_legacy,
    bool show_debug_info, TF_Status* status) {
  // Load the saved model into a SavedModelBundle.
  std::unordered_set<string> tag_set =
      absl::StrSplit(tags, ',', absl::SkipEmpty());

  tensorflow::SavedModelBundle bundle;
  auto load_status =
      tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
  if (!load_status.ok()) {
    Set_TF_Status_from_Status(status, load_status);
    return "// error";
  }

  // Convert the SavedModelBundle to an MLIR module.
  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;
  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or =
      ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
                                &context, import_options, lift_variables);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
  }

  // Run the tf standard pipeline by default and then run passes that lift
  // variables on the module if the flag is set.
  mlir::OwningOpRef<mlir::ModuleOp> module = std::move(module_or).value();
  mlir::PassManager pm(&context);
  std::string error;
  llvm::raw_string_ostream error_stream(error);

  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);

  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
  if (failed(pm.run(*module))) {
    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
    return "// error";
  }
  return MlirModuleToString(*module, show_debug_info);
}

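// Parses `mlir_txt` as an MLIR module (with all TensorFlow dialects
// registered) and runs `pass_pipeline` over it, returning the transformed
// module in textual form.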
std::string ExperimentalRunPassPipeline(const std::string& mlir_txt,
                                        const std::string& pass_pipeline,
                                        bool show_debug_info,
                                        TF_Status* status) {
  RegisterPasses();
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module;
  {
    mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
    module = mlir::parseSourceString<mlir::ModuleOp>(mlir_txt, &context);
    if (!module) {
      Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
      return "// error";
    }
  }

  // Run the pass_pipeline on the module.
  mlir::PassManager pm(&context);
  std::string error;
  llvm::raw_string_ostream error_stream(error);
  if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
    return "// error";
  }

  mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
  if (failed(pm.run(*module))) {
    Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
    return "// error";
  }
  return MlirModuleToString(*module, show_debug_info);
}

}  // namespace tensorflow