/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

#include "absl/memory/memory.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

namespace tensorflow {

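// Parses `input` as a GraphDef proto, loads the optional debug info file,
// optionally prunes the graph to the transitive fanin of the requested
// input, output, and control-output arrays, and imports the result into an
// MLIR module.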
static StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirImport(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
  GraphDef graphdef;
  TF_RETURN_IF_ERROR(
      tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef));

  GraphDebugInfo debug_info;
  if (!debug_info_file.empty()) {
    TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_file, &debug_info));
  }

  GraphImportConfig specs;
  specs.prune_unused_nodes = prune_unused_nodes;
  specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs;
  specs.graph_as_function = graph_as_function;
  specs.upgrade_legacy = upgrade_legacy;
  specs.enable_shape_inference = enable_shape_inference;
  specs.unconditionally_use_set_output_shapes =
      unconditionally_use_set_output_shapes;
  TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes,
                                         input_shapes, &specs.inputs));
  TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs));
  TF_RETURN_IF_ERROR(
      ParseOutputArrayInfo(control_output_arrays, &specs.control_outputs));
  // TODO(b/142828368): Pruning should not be needed when the TF importer
  // supports importing graphs with unregistered ops natively.
  GraphDef pruned_graph_def;
  if (specs.prune_unused_nodes) {
    std::vector<std::string> terminal_nodes;
    terminal_nodes.reserve(specs.outputs.size() + specs.inputs.size());
    for (const auto& output : specs.outputs) {
      terminal_nodes.push_back(std::string(ParseTensorName(output).node()));
    }
    for (const auto& control_output : specs.control_outputs) {
      terminal_nodes.push_back(std::string(control_output));
    }
    for (const auto& input : specs.inputs) {
      terminal_nodes.push_back(input.first);
    }
    TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
        graphdef, &pruned_graph_def, terminal_nodes));
    // TODO(ashwinm): Add a separate utility to grappler utils that abstracts
    // both SetTransitiveFaninGraph and restoring the missing contents from the
    // original graph, such as the function def library and versions.
    pruned_graph_def.mutable_library()->Swap(graphdef.mutable_library());
    pruned_graph_def.mutable_versions()->Swap(graphdef.mutable_versions());
  }
  return ConvertGraphdefToMlir(
      specs.prune_unused_nodes ? pruned_graph_def : graphdef, debug_info, specs,
      context);
}

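// Translates a GraphDef to an MLIR module using already-parsed array lists,
// logging the status on import failure so callers get a diagnostic even if
// they only propagate the error.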
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, unconditionally_use_set_output_shapes, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
  }
  return module_or;
}

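// Overload that accepts the input/output array, dtype, and shape lists in
// their flag-style string form and parses them before delegating to the
// vector-based overload above.
//
// A minimal usage sketch (the variable `graphdef_bytes` is hypothetical; it
// stands for a string holding a serialized GraphDef, with all array lists
// left empty):
//
//   mlir::MLIRContext context;
//   auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
//       graphdef_bytes, /*debug_info_file=*/"", /*input_arrays=*/"",
//       /*input_dtypes=*/"", /*input_shapes=*/"", /*output_arrays=*/"",
//       /*control_output_arrays=*/"", /*prune_unused_nodes=*/false,
//       /*convert_legacy_fed_inputs=*/false, /*graph_as_function=*/false,
//       /*upgrade_legacy=*/false, /*enable_shape_inference=*/false,
//       /*unconditionally_use_set_output_shapes=*/false, &context);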
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphdefToMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, bool unconditionally_use_set_output_shapes,
    mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference,
      unconditionally_use_set_output_shapes, context);
}

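// Loads a TF2 SavedModel (object graph) from `saved_model_dir` and converts
// it to an MLIR module, logging an error if either the load or the
// conversion fails.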
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SavedModelObjectGraphToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    bool unconditionally_use_set_output_shapes) {
  tensorflow::SavedModelV2Bundle bundle;
  auto load_status = tensorflow::SavedModelV2Bundle::Load(
      std::string(saved_model_dir.data(), saved_model_dir.length()), &bundle);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or = ConvertSavedModelToMlir(
      &bundle, context, exported_names, /*add_default_attributes=*/true,
      /*unconditionally_use_set_output_shapes=*/
      unconditionally_use_set_output_shapes);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel import failed: " << module_or.status();
  }
  return module_or;
}

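// Loads a TF1 SavedModel (signature defs) into `saved_model_bundle` if
// provided (or into a local bundle otherwise), with GPU devices disabled so
// state is restored on the CPU, then converts the bundle to an MLIR module.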
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> SavedModelSignatureDefsToMlirImport(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options, bool lift_variables,
    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle) {
  // Create a local bundle if the caller did not provide one to use.
  std::unique_ptr<tensorflow::SavedModelBundle> bundle;
  if (saved_model_bundle == nullptr) {
    bundle = std::make_unique<tensorflow::SavedModelBundle>();
  } else if (*saved_model_bundle == nullptr) {
    *saved_model_bundle = std::make_unique<tensorflow::SavedModelBundle>();
  }
  SavedModelBundle* bundle_ptr =
      saved_model_bundle ? saved_model_bundle->get() : bundle.get();
  tensorflow::SessionOptions session_options;

  // Force the saved model's state to be restored on the CPU.
  (*session_options.config.mutable_device_count())["GPU"] = 0;
  auto load_status = tensorflow::LoadSavedModel(
      session_options, /*run_options=*/{}, std::string(saved_model_dir), tags,
      bundle_ptr);
  if (!load_status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << load_status;
    return load_status;
  }

  auto module_or = ConvertSavedModelV1ToMlir(*bundle_ptr, exported_names,
                                             context, options, lift_variables);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

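// "Lite" variant that only reads the MetaGraphDef from the SavedModel, so no
// session is created and no variables are restored, before converting to an
// MLIR module.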
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelSignatureDefsToMlirImportLite(
    absl::string_view saved_model_dir,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
    MLIRImportOptions options) {
  MetaGraphDef meta_graph_def;
  auto status = ReadMetaGraphDefFromSavedModel(std::string(saved_model_dir),
                                               tags, &meta_graph_def);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
               << "': " << status;
    return status;
  }

  std::optional<absl::Span<const std::string>> optional_exported_names;
  if (!exported_names.empty()) optional_exported_names = exported_names;

  // TODO(b/186898924): The debug info in the SavedModel should not be
  // ignored; it should be passed through here.
  auto module_or =
      ConvertSavedModelV1ToMlirLite(meta_graph_def, /*debug_info=*/{},
                                    optional_exported_names, context, options);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
  }
  return module_or;
}

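// Variant of the GraphDef translation that, after import, replaces each dense
// "value" elements attribute with a splat of a single pseudo-random scalar
// (seeded with 0 for determinism), preserving attribute types while
// shrinking constants. Unsupported element types are skipped with a warning.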
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    const std::vector<std::string>& input_arrays,
    const std::vector<std::string>& input_dtypes,
    const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
    const std::vector<std::string>& output_arrays,
    const std::vector<std::string>& control_output_arrays,
    bool prune_unused_nodes, bool convert_legacy_fed_inputs,
    bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference,
    bool unconditionally_use_set_output_shapes, mlir::MLIRContext* context) {
  auto module_or = GraphdefToMlirImport(
      input, debug_info_file, input_arrays, input_dtypes, input_shapes,
      output_arrays, control_output_arrays, prune_unused_nodes,
      convert_legacy_fed_inputs, graph_as_function, upgrade_legacy,
      enable_shape_inference, unconditionally_use_set_output_shapes, context);
  if (!module_or.status().ok()) {
    LOG(ERROR) << "Graph import failed: " << module_or.status();
    return module_or.status();
  }
  auto& module = module_or.ValueOrDie();
  std::srand(0);
  for (auto fn : module->getOps<mlir::func::FuncOp>()) {
    for (auto& bb : fn) {
      for (auto& inst : bb) {
        auto attr_id = mlir::StringAttr::get(context, "value");
        if (auto attr = inst.getAttrOfType<mlir::ElementsAttr>(attr_id)) {
          mlir::Attribute rand_val;
          mlir::Type element_type = attr.getType().getElementType();
          if (element_type.isa<mlir::IntegerType>()) {
            rand_val = mlir::IntegerAttr::get(element_type, std::rand());
          } else if (element_type.isF16() || element_type.isF32() ||
                     element_type.isF64()) {
            rand_val = mlir::FloatAttr::get(element_type,
                                            std::rand() * 1.0 / RAND_MAX);
          } else {
            inst.emitWarning()
                << "Skipping splat conversion for "
                << "an unsupported attribute type " << element_type;
            continue;
          }
          auto new_attr =
              mlir::DenseElementsAttr::get(attr.getType(), rand_val);
          inst.setAttr(attr_id, new_attr);
        }
      }
    }
  }
  return module_or;
}

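// Overload of the splatted translation that parses the flag-style string
// forms of the array, dtype, and shape lists before delegating to the
// vector-based overload above.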
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
GraphdefToSplattedMlirTranslateFunction(
    llvm::StringRef input, absl::string_view debug_info_file,
    absl::string_view input_arrays, absl::string_view input_dtypes,
    absl::string_view input_shapes, absl::string_view output_arrays,
    absl::string_view control_output_arrays, bool prune_unused_nodes,
    bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy,
    bool enable_shape_inference, bool unconditionally_use_set_output_shapes,
    mlir::MLIRContext* context) {
  std::vector<std::string> input_array_vector;
  std::vector<std::string> input_dtype_vector;
  std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
  std::vector<std::string> output_array_vector;
  std::vector<std::string> control_output_array_vector;
  TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector));
  TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector));
  TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector));
  TF_RETURN_IF_ERROR(
      ParseNodeNames(control_output_arrays, control_output_array_vector));
  return GraphdefToSplattedMlirTranslateFunction(
      input, debug_info_file, input_array_vector, input_dtype_vector,
      input_shapes_vector, output_array_vector, control_output_array_vector,
      prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function,
      upgrade_legacy, enable_shape_inference,
      unconditionally_use_set_output_shapes, context);
}

}  // namespace tensorflow