xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
17 
18 #include "absl/memory/memory.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/Parser/Parser.h"  // from @llvm-project
27 #include "tensorflow/cc/saved_model/bundle_v2.h"
28 #include "tensorflow/cc/saved_model/reader.h"
29 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
30 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/graph/tensor_id.h"
37 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
41 
42 namespace tensorflow {
43 
GraphdefToMlirImport(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,bool unconditionally_use_set_output_shapes,mlir::MLIRContext * context)44 static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirImport(
45     llvm::StringRef input, absl::string_view debug_info_file,
46     const std::vector<std::string>& input_arrays,
47     const std::vector<std::string>& input_dtypes,
48     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
49     const std::vector<std::string>& output_arrays,
50     const std::vector<std::string>& control_output_arrays,
51     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
52     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
53     bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
54   GraphDef graphdef;
55   TF_RETURN_IF_ERROR(
56       tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));
57 
58   GraphDebugInfo debug_info;
59   if (!debug_info_file.empty()) {
60     TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
61   }
62 
63   GraphImportConfig specs;
64   specs.prune_unused_nodes = prune_unused_nodes;
65   specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
66   specs.graph_as_function = graph_as_function;
67   specs.upgrade_legacy = upgrade_legacy;
68   specs.enable_shape_inference = enable_shape_inference;
69   specs.unconditionally_use_set_output_shapes =
70       unconditionally_use_set_output_shapes;
71   TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
72                                          input_shapes, &specs.inputs));
73   TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
74   TF_RETURN_IF_ERROR(
75       ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
76   // TODO(b/142828368): Pruning should not be needed when TF import
77   // supports importing graphs w/ unregistered ops natively.
78   GraphDef pruned_graph_def;
79   if (specs.prune_unused_nodes) {
80     std::vector<std::string> terminal_nodes;
81     terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
82     for (const auto& output : specs.outputs) {
83       terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
84     }
85     for (const auto& control_output : specs.control_outputs) {
86       terminal_nodes.push_back(std::string(control_output));
87     }
88     for (const auto& input : specs.inputs) {
89       terminal_nodes.push_back(input.first);
90     }
91     TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
92         graphdef, &pruned_graph_def, terminal_nodes));
93     // TODO(ashwinm): Add a separate utility in grappler utils that abstracts
94     // both SetTransitiveFaninGraph and restoring the missing contents from the
95     // original graph like function def library and version.
96     pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
97     pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
98   }
99   return ConvertGraphdefToMlir(
100       specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info, specs,
101       context);
102 }
103 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,bool unconditionally_use_set_output_shapes,mlir::MLIRContext * context)104 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirTranslateFunction(
105     llvm::StringRef input, absl::string_view debug_info_file,
106     const std::vector<std::string>& input_arrays,
107     const std::vector<std::string>& input_dtypes,
108     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
109     const std::vector<std::string>& output_arrays,
110     const std::vector<std::string>& control_output_arrays,
111     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
112     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
113     bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
114   auto module_or = GraphdefToMlirImport(
115       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
116       output_arrays, control_output_arrays, prune_unused_nodes,
117       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
118       enable_shape_inference, unconditionally_use_set_output_shapes, context);
119   if (!module_or.status().ok()) {
120     LOG(ERROR) << "Graph import failed: " << module_or.status();
121   }
122   return module_or;
123 }
124 
GraphdefToMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,bool unconditionally_use_set_output_shapes,mlir::MLIRContext * context)125 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirTranslateFunction(
126     llvm::StringRef input, absl::string_view debug_info_file,
127     absl::string_view input_arrays, absl::string_view input_dtypes,
128     absl::string_view input_shapes, absl::string_view output_arrays,
129     absl::string_view control_output_arrays, bool prune_unused_nodes,
130     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
131     bool enable_shape_inference, bool unconditionally_use_set_output_shapes,
132     mlir::MLIRContext* context) {
133   std::vector<std::string> input_array_vector;
134   std::vector<std::string> input_dtype_vector;
135   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
136   std::vector<std::string> output_array_vector;
137   std::vector<std::string> control_output_array_vector;
138   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
139   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
140   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
141   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
142   TF_RETURN_IF_ERROR(
143       ParseNodeNames(control_output_arrays, control_output_array_vector));
144   return GraphdefToMlirTranslateFunction(
145       input, debug_info_file, input_array_vector, input_dtype_vector,
146       input_shapes_vector, output_array_vector, control_output_array_vector,
147       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
148       upgrade_legacy, enable_shape_inference,
149       unconditionally_use_set_output_shapes, context);
150 }
151 
SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,bool unconditionally_use_set_output_shapes)152 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SavedModelObjectGraphToMlirImport(
153     absl::string_view saved_model_dir,
154     const std::unordered_set<std::string>& tags,
155     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
156     bool unconditionally_use_set_output_shapes) {
157   tensorflow::SavedModelV2Bundle bundle;
158   auto load_status = tensorflow::SavedModelV2Bundle::Load(
159       std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
160   if (!load_status.ok()) {
161     LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
162                << "': " << load_status;
163     return load_status;
164   }
165 
166   auto module_or = ConvertSavedModelToMlir(
167       &bundle, context, exported_names, /*add_default_attributes=*/true,
168       /*unconditionally_use_set_output_shapes=*/
169       unconditionally_use_set_output_shapes);
170   if (!module_or.status().ok()) {
171     LOG(ERROR) << "SavedModel import failed: " << module_or.status();
172   }
173   return module_or;
174 }
175 
SavedModelSignatureDefsToMlirImport(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options,bool lift_variables,std::unique_ptr<tensorflow::SavedModelBundle> * saved_model_bundle)176 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SavedModelSignatureDefsToMlirImport(
177     absl::string_view saved_model_dir,
178     const std::unordered_set<std::string>& tags,
179     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
180     MLIRImportOptions options, bool lift_variables,
181     std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
182   // Create local bundle if no one is provided to use.
183   std::unique_ptr<tensorflow::SavedModelBundle> bundle;
184   if (saved_model_bundle == nullptr) {
185     bundle = std::make_unique<tensorflow::SavedModelBundle>();
186   } else if (*saved_model_bundle == nullptr) {
187     *saved_model_bundle = std::make_unique<tensorflow::SavedModelBundle>();
188   }
189   SavedModelBundle* bundle_ptr =
190       saved_model_bundle ? saved_model_bundle->get() : bundle.get();
191   tensorflow::SessionOptions session_options;
192 
193   // Force saved model states to be restored to CPU.
194   (*session_options.config.mutable_device_count())["GPU"] = 0;
195   auto load_status = tensorflow::LoadSavedModel(
196       session_options, /* run_options = */ {}, std::string(saved_model_dir),
197       tags, bundle_ptr);
198   if (!load_status.ok()) {
199     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
200                << "': " << load_status;
201     return load_status;
202   }
203 
204   auto module_or = ConvertSavedModelV1ToMlir(*bundle_ptr, exported_names,
205                                              context, options, lift_variables);
206   if (!module_or.status().ok()) {
207     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
208   }
209   return module_or;
210 }
211 
212 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefsToMlirImportLite(absl::string_view saved_model_dir,const std::unordered_set<std::string> & tags,absl::Span<std::string> exported_names,mlir::MLIRContext * context,MLIRImportOptions options)213 SavedModelSignatureDefsToMlirImportLite(
214     absl::string_view saved_model_dir,
215     const std::unordered_set<std::string>& tags,
216     absl::Span<std::string> exported_names, mlir::MLIRContext* context,
217     MLIRImportOptions options) {
218   MetaGraphDef meta_graph_def;
219   auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
220                                                tags, &meta_graph_def);
221   if (!status.ok()) {
222     LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
223                << "': " << status;
224     return status;
225   }
226 
227   std::optional<absl::Span<const std::string>> optional_exported_names;
228   if (!exported_names.empty()) optional_exported_names = exported_names;
229 
230   // TODO(b/186898924): debug info in the savedmodel should not be ignored and
231   // should be passed here.
232   auto module_or =
233       ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
234                                     optional_exported_names, context, options);
235   if (!module_or.status().ok()) {
236     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
237   }
238   return module_or;
239 }
240 
241 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,const std::vector<std::string> & input_arrays,const std::vector<std::string> & input_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & input_shapes,const std::vector<std::string> & output_arrays,const std::vector<std::string> & control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,bool unconditionally_use_set_output_shapes,mlir::MLIRContext * context)242 GraphdefToSplattedMlirTranslateFunction(
243     llvm::StringRef input, absl::string_view debug_info_file,
244     const std::vector<std::string>& input_arrays,
245     const std::vector<std::string>& input_dtypes,
246     const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
247     const std::vector<std::string>& output_arrays,
248     const std::vector<std::string>& control_output_arrays,
249     bool prune_unused_nodes, bool convert_legacy_fed_inputs,
250     bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
251     bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
252   auto module_or = GraphdefToMlirImport(
253       input, debug_info_file, input_arrays, input_dtypes, input_shapes,
254       output_arrays, control_output_arrays, prune_unused_nodes,
255       convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
256       enable_shape_inference, unconditionally_use_set_output_shapes, context);
257   if (!module_or.status().ok()) {
258     LOG(ERROR) << "Graph import failed: " << module_or.status();
259     return module_or.status();
260   }
261   auto& module = module_or.ValueOrDie();
262   std::srand(0);
263   for (auto fn : module->getOps<mlir::func::FuncOp>()) {
264     for (auto& bb : fn) {
265       for (auto& inst : bb) {
266         auto attr_id = mlir::StringAttr::get(context, "value");
267         if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
268           mlir::Attribute rand_val;
269           mlir::Type element_type = attr.getType().getElementType();
270           if (element_type.isa<mlir::IntegerType>()) {
271             rand_val = mlir::IntegerAttr::get(element_type, std::rand());
272           } else if (element_type.isF16() || element_type.isF32() ||
273                      element_type.isF64()) {
274             rand_val = mlir::FloatAttr::get(element_type,
275                                             std::rand() * 1.0 / RAND_MAX);
276 
277           } else {
278             inst.emitWarning()
279                 << "Skipping splat conversion for "
280                 << "an unsupported attribute type " << element_type;
281             continue;
282           }
283           auto new_attr =
284               mlir::DenseElementsAttr::get(attr.getType(), rand_val);
285           inst.setAttr(attr_id, new_attr);
286         }
287       }
288     }
289   }
290   return module_or;
291 }
292 
293 StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphdefToSplattedMlirTranslateFunction(llvm::StringRef input,absl::string_view debug_info_file,absl::string_view input_arrays,absl::string_view input_dtypes,absl::string_view input_shapes,absl::string_view output_arrays,absl::string_view control_output_arrays,bool prune_unused_nodes,bool convert_legacy_fed_inputs,bool graph_as_function,bool upgrade_legacy,bool enable_shape_inference,bool unconditionally_use_set_output_shapes,mlir::MLIRContext * context)294 GraphdefToSplattedMlirTranslateFunction(
295     llvm::StringRef input, absl::string_view debug_info_file,
296     absl::string_view input_arrays, absl::string_view input_dtypes,
297     absl::string_view input_shapes, absl::string_view output_arrays,
298     absl::string_view control_output_arrays, bool prune_unused_nodes,
299     bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
300     bool enable_shape_inference, bool unconditionally_use_set_output_shapes,
301     mlir::MLIRContext* context) {
302   std::vector<std::string> input_array_vector;
303   std::vector<std::string> input_dtype_vector;
304   std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
305   std::vector<std::string> output_array_vector;
306   std::vector<std::string> control_output_array_vector;
307   TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
308   TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
309   TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
310   TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
311   TF_RETURN_IF_ERROR(
312       ParseNodeNames(control_output_arrays, control_output_array_vector));
313   return GraphdefToSplattedMlirTranslateFunction(
314       input, debug_info_file, input_array_vector, input_dtype_vector,
315       input_shapes_vector, output_array_vector, control_output_array_vector,
316       prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
317       upgrade_legacy, enable_shape_inference,
318       unconditionally_use_set_output_shapes, context);
319 }
320 
321 }  // namespace tensorflow
322