/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Casting.h"
33 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
34 #include "mlir/IR/Attributes.h" // from @llvm-project
35 #include "mlir/IR/Builders.h" // from @llvm-project
36 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
37 #include "mlir/IR/Location.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/SymbolTable.h" // from @llvm-project
40 #include "mlir/IR/Types.h" // from @llvm-project
41 #include "mlir/Pass/Pass.h" // from @llvm-project
42 #include "mlir/Pass/PassManager.h" // from @llvm-project
43 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
44 #include "mlir/Support/LogicalResult.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
48 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
49 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
52 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
53 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
54 #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
55 #include "tensorflow/compiler/mlir/utils/name_utils.h"
56 #include "tensorflow/compiler/xla/status_macros.h"
57 #include "tensorflow/core/framework/graph.pb.h"
58 #include "tensorflow/core/framework/graph_to_functiondef.h"
59 #include "tensorflow/core/framework/node_def.pb.h"
60 #include "tensorflow/core/framework/node_def_util.h"
61 #include "tensorflow/core/framework/op.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/framework/versions.pb.h"
64 #include "tensorflow/core/graph/algorithm.h"
65 #include "tensorflow/core/graph/graph.h"
66 #include "tensorflow/core/graph/tensor_id.h"
67 #include "tensorflow/core/lib/core/errors.h"
68 #include "tensorflow/core/lib/core/status.h"
69
namespace tensorflow {
// Convenience aliases used throughout this translation unit.
using llvm::dyn_cast;
using llvm::isa;
using mlir::BlockArgument;
using mlir::Dialect;
using mlir::Operation;
using mlir::SymbolTable;
using mlir::Value;
using mlir::func::FuncOp;
using stream_executor::port::StatusOr;

namespace {

// Names of MLIR attributes recognized (and consumed) during export.
constexpr char kDeviceAttr[] = "tf.device";
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
constexpr char kEntryFuncAttr[] = "tf.entry_function";
constexpr char kAliasingAttr[] = "tf.aliasing_output";
87
88 // OpOrArgLocNameMapper that legalizes the returned name.
89 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
90 private:
GetName(OpOrVal op_or_val)91 std::string GetName(OpOrVal op_or_val) override {
92 std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
93 assert(!name.empty() && "expected non-empty name");
94 mlir::LegalizeNodeName(name);
95 return name;
96 }
97 };
98
99 // Finds first inner op if `op` is a tf_executor.island. Otherwise `op` is
100 // returned.
GetIslandInnerOpOrSelf(mlir::Operation * op)101 Operation* GetIslandInnerOpOrSelf(mlir::Operation* op) {
102 auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(op);
103 if (island) return &island.GetBody().front();
104 return op;
105 }
106
// Stateful helper class to export a function into a Graph.
class Exporter {
 public:
  // Converts the given Module to a Graph. The given module should only contain
  // one entry function, which is identified by name "main". This entry function
  // is converted to the base of the graph graph. The rest of the functions are
  // converted to the library functions in that graph.
  static Status Convert(mlir::ModuleOp module, const GraphExportConfig& configs,
                        std::unique_ptr<Graph>* graph,
                        FunctionLibraryDefinition* flib_def,
                        absl::flat_hash_set<Node*>* control_ret_nodes);

  // Converts a given FuncOp to a FunctionDef and adds it to the function
  // definition library.
  static Status ConvertLibFunction(
      const GraphExportConfig& configs, const Dialect* tf_dialect,
      const SymbolTable& symbol_table, FuncOp function,
      FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions);

  // Converts the given FuncOp to a Graph. The arguments and returns of
  // function are added to the graph with special op names kArgOp and kRetOp.
  // Later on, this graph can be converted a function definition and added to
  // another graph.
  static StatusOr<std::unique_ptr<Graph>> Convert(
      const GraphExportConfig& configs, const Dialect* tf_dialect,
      const SymbolTable& symbol_table, FuncOp function,
      FunctionDefLibrary* flib, llvm::SmallDenseSet<FuncOp>& visited_functions,
      absl::flat_hash_set<Node*>* control_ret_nodes);

 private:
  explicit Exporter(Graph* graph, const Dialect* tf_dialect)
      : graph_(graph), tf_dialect_(tf_dialect) {}

  // Adds an argument node for block argument `arg` at position `index`. If
  // `name` is non-empty it is used; otherwise a unique name is generated.
  Status AddArgumentNode(BlockArgument arg, unsigned index,
                         llvm::StringRef name);
  // Adds one return node per data operand of `fetch`. `names`, when
  // non-empty, supplies the node names in operand order.
  Status AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
                      llvm::ArrayRef<llvm::StringRef> names);
  // Converts `inst` to a NodeDef and adds the resulting node to the graph.
  Status AddInstructionNode(Operation* inst);
  // Adds graph edges for all operands of `inst`.
  Status AddEdge(Operation* inst);

  // Builds the NodeDef for a function argument (kArgOp).
  StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
                                                     unsigned index,
                                                     llvm::StringRef name);
  // Builds the NodeDef for a function result (kRetOp).
  StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(FuncOp function,
                                                   Value operand,
                                                   unsigned index,
                                                   llvm::StringRef name);
  // Collects the graph nodes that correspond to `fetch`'s control operands
  // into `control_ret_nodes`.
  Status GetControlRetNodes(mlir::tf_executor::FetchOp fetch,
                            absl::flat_hash_set<Node*>* control_ret_nodes);
  // Adds one edge between src_node and dst_node. If it is not a control edge,
  // an index is used to find out the right operand of the dst_node.
  Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);

  // The graph being built. Not owned.
  Graph* graph_;
  // Generates unique, legalized node names.
  LegalizedOpOrValLocNameMapper op_to_name_;
  // Maps each exported MLIR operation to its graph node.
  absl::flat_hash_map<Operation*, Node*> nodes_;
  // Maps each block argument to its argument node.
  llvm::DenseMap<BlockArgument, Node*> args_;
  // One single return operation can return multiple results, and each of them
  // will be converted to one node in the graph.
  typedef absl::InlinedVector<Node*, 4> NodeVector;
  absl::flat_hash_map<Operation*, NodeVector> returns_;
  // The TensorFlow dialect instance. Not owned.
  const mlir::Dialect* tf_dialect_;
};
170
GetArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)171 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
172 BlockArgument arg, unsigned index, llvm::StringRef name) {
173 auto func = arg.getParentRegion()->getParentOfType<FuncOp>();
174
175 auto node_def = std::make_unique<NodeDef>();
176 if (!name.empty())
177 node_def->set_name(std::string(ParseTensorName(name.str()).node()));
178 else
179 node_def->set_name(
180 std::string(op_to_name_.GetUniqueName(func.getName().str())));
181
182 node_def->set_op(FunctionLibraryDefinition::kArgOp);
183
184 mlir::TensorType arg_type = arg.getType().cast<mlir::TensorType>();
185 if (auto resource_type =
186 arg_type.getElementType().dyn_cast<mlir::TF::ResourceType>()) {
187 llvm::ArrayRef<mlir::TensorType> subtypes = resource_type.getSubtypes();
188 if (!subtypes.empty()) {
189 AttrValue handle_dtypes_attr;
190 AttrValue handle_shapes_attr;
191 for (mlir::TensorType subtype : subtypes) {
192 DataType dtype;
193 TF_RETURN_IF_ERROR(ConvertToDataType(subtype.getElementType(), &dtype));
194 handle_dtypes_attr.mutable_list()->add_type(dtype);
195
196 SetTensorShapeProto(subtype,
197 handle_shapes_attr.mutable_list()->add_shape());
198 }
199
200 (*node_def->mutable_attr())["_handle_dtypes"] = handle_dtypes_attr;
201 (*node_def->mutable_attr())["_handle_shapes"] = handle_shapes_attr;
202 }
203 }
204
205 TF_RETURN_IF_ERROR(
206 SetShapeAttribute("_output_shapes", arg_type, node_def->mutable_attr()));
207
208 DataType dtype;
209 TF_RETURN_IF_ERROR(ConvertToDataType(arg_type.getElementType(), &dtype));
210 AttrValue type_attr;
211 type_attr.set_type(dtype);
212 (*node_def->mutable_attr())["T"] = type_attr;
213
214 AttrValue index_attr;
215 index_attr.set_i(index);
216 (*node_def->mutable_attr())["index"] = index_attr;
217
218 if (auto device_attr =
219 func.getArgAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
220 *node_def->mutable_device() = device_attr.getValue().str();
221
222 llvm::ArrayRef<mlir::NamedAttribute> func_arg_i_attrs =
223 func.getArgAttrs(index);
224 absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr,
225 kAliasingAttr};
226 TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore,
227 /*remove_ref_type=*/false,
228 node_def->mutable_attr()));
229
230 return node_def;
231 }
232
GetReturnNode(FuncOp function,Value operand,unsigned index,llvm::StringRef name)233 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
234 FuncOp function, Value operand, unsigned index, llvm::StringRef name) {
235 auto node_def = std::make_unique<NodeDef>();
236 if (!name.empty())
237 node_def->set_name(std::string(ParseTensorName(name.str()).node()));
238 else
239 node_def->set_name(
240 std::string(op_to_name_.GetUniqueName(function.getName().str())));
241
242 node_def->set_op(FunctionLibraryDefinition::kRetOp);
243 DataType dtype;
244 TF_RETURN_IF_ERROR(ConvertToDataType(
245 operand.getType().cast<mlir::TensorType>().getElementType(), &dtype));
246 AttrValue type_attr;
247 type_attr.set_type(dtype);
248 (*node_def->mutable_attr())["T"] = type_attr;
249 AttrValue index_attr;
250 index_attr.set_i(index);
251 (*node_def->mutable_attr())["index"] = index_attr;
252
253 if (auto device_attr =
254 function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
255 *node_def->mutable_device() = device_attr.getValue().str();
256
257 llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
258 function.getResultAttrs(index);
259 absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
260 TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
261 /*remove_ref_type=*/false,
262 node_def->mutable_attr()));
263
264 return node_def;
265 }
266
AddEdgeBetweenNodes(Value src,Node * dst_node,unsigned dst_index)267 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
268 unsigned dst_index) {
269 if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
270 auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner());
271 // Replaces the input node with NextIteration sink if it is a NextIteration
272 // source.
273 if (auto next_iter_source =
274 llvm::dyn_cast<mlir::tf_executor::NextIterationSourceOp>(
275 input_inst))
276 input_inst = next_iter_source.GetSink();
277
278 auto node_it = nodes_.find(input_inst);
279 TF_RET_CHECK(node_it != nodes_.end())
280 << "Use of OpResult encountered before def!";
281 if (input_result.getType().isa<mlir::tf_executor::ControlType>()) {
282 graph_->AddControlEdge(node_it->second, dst_node);
283 } else {
284 graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
285 dst_index);
286 }
287 return OkStatus();
288 }
289
290 auto input_arg = src.cast<BlockArgument>();
291 auto input_node_it = args_.find(input_arg);
292 TF_RET_CHECK(input_node_it != args_.end())
293 << "Use of BlockArgument encounted before def!";
294 // For argument, there is only one result output, so the index is always 0.
295 graph_->AddEdge(input_node_it->second, 0, dst_node, dst_index);
296 return OkStatus();
297 }
298
// Adds graph edges for all operands of `inst`, dispatching on the executor
// op kind because each kind maps its MLIR operands to node inputs
// differently.
Status Exporter::AddEdge(Operation* inst) {
  // For tf_executor.fetch, add only its data edges. Control edges are captured
  // later.
  if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
      Value operand = operand_and_idx.value();
      // Data operands precede control operands, so stop at the first control.
      if (operand.getType().isa<mlir::tf_executor::ControlType>()) break;

      // Each data operand feeds the matching return node created by
      // AddFetchNode; return nodes take a single input (index 0).
      auto* dst_node = returns_[fetch][operand_and_idx.index()];
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0));
    }

    return OkStatus();
  }

  // For tf_executor.NextIteration.Sink, skip its token operand and add data and
  // control edges with their index offset by 1.
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(inst)) {
    auto* dst_node = nodes_[inst];
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(next_iter_sink.input(), dst_node, 0));
    for (auto control_and_idx : llvm::enumerate(next_iter_sink.controlInputs()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(control_and_idx.value(), dst_node,
                                             control_and_idx.index() + 1));

    return OkStatus();
  }

  // For tf_executor.NextIteration.Source, op can be skipped as it is assumed
  // there are no operands.
  if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
    assert(inst->getNumOperands() == 0);
    return OkStatus();
  }

  Operation* op = GetIslandInnerOpOrSelf(inst);
  auto* dst_node = nodes_[op];
  int operand_offset = 0;
  // For tf_executor.island, add data edges from its wrapped op before control
  // edges.
  if (auto island = llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
    for (auto operand_and_idx : llvm::enumerate(op->getOperands()))
      TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                                             operand_and_idx.index()));

    // The island's own (control) operands follow the inner op's data inputs.
    operand_offset = op->getNumOperands();
  }

  // For all other ops (including tf_executor.island), add remaining edges.
  for (auto operand_and_idx : llvm::enumerate(inst->getOperands()))
    TF_RETURN_IF_ERROR(
        AddEdgeBetweenNodes(operand_and_idx.value(), dst_node,
                            operand_and_idx.index() + operand_offset));

  return OkStatus();
}
356
AddInstructionNode(Operation * inst)357 Status Exporter::AddInstructionNode(Operation* inst) {
358 std::unique_ptr<NodeDef> node_def;
359 auto name = op_to_name_.GetUniqueName(inst);
360 // Convert registered TF ops to NodeDef. Only registered ops are handled to
361 // ensure that PopulateDerivedAttrs adds the correct attributes.
362 TF_ASSIGN_OR_RETURN(node_def,
363 ConvertTFDialectOpToNodeDef(
364 inst, name, /*ignore_unregistered_attrs=*/false));
365
366 TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
367 DCHECK(node != nullptr);
368 nodes_[inst] = node;
369 return OkStatus();
370 }
371
IsEntryFunctionArg(BlockArgument arg)372 bool IsEntryFunctionArg(BlockArgument arg) {
373 return arg.getParentRegion()->getParentOfType<FuncOp>().getName() == "main";
374 }
375
376 // Creates argument nodes from Block argument. If a name is supplied, that
377 // name will be used instead of generating a unique name.
AddArgumentNode(BlockArgument arg,unsigned index,llvm::StringRef name)378 Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
379 llvm::StringRef name) {
380 TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
381 TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
382 args_[arg] = node;
383 return OkStatus();
384 }
385
386 // Creates return nodes per operand of a FetchOp. If names is supplied, those
387 // names will be used per node in order instead of generating a unique name.
AddFetchNode(FuncOp function,mlir::tf_executor::FetchOp fetch,llvm::ArrayRef<llvm::StringRef> names)388 Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch,
389 llvm::ArrayRef<llvm::StringRef> names) {
390 auto& return_nodes = returns_[fetch];
391 for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) {
392 if (operand_and_idx.value().getType().isa<mlir::tf_executor::ControlType>())
393 break;
394
395 TF_ASSIGN_OR_RETURN(
396 auto node_def,
397 GetReturnNode(function, operand_and_idx.value(),
398 operand_and_idx.index(),
399 names.empty() ? "" : names[operand_and_idx.index()]));
400 TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def));
401 return_nodes.push_back(node);
402 }
403 return OkStatus();
404 }
405
406 // Collects control ret Nodes based on tf_executor.graph's associated
407 // tf_executor.fetch control inputs.
GetControlRetNodes(mlir::tf_executor::FetchOp fetch,absl::flat_hash_set<Node * > * control_ret_nodes)408 Status Exporter::GetControlRetNodes(
409 mlir::tf_executor::FetchOp fetch,
410 absl::flat_hash_set<Node*>* control_ret_nodes) {
411 for (Value fetch_operand : fetch.getOperands()) {
412 if (fetch_operand.getType().isa<mlir::tf_executor::ControlType>()) {
413 Operation* defining_op =
414 GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp());
415 auto node_it = nodes_.find(defining_op);
416 TF_RET_CHECK(node_it != nodes_.end());
417 control_ret_nodes->insert(node_it->second);
418 }
419 }
420 return OkStatus();
421 }
422
// Converts `function` into a Graph: argument nodes for the block arguments,
// one graph node per executor op, return nodes for the fetch's data
// operands, then all edges. Library functions referenced from the body are
// converted recursively into `flib`.
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
    llvm::SmallDenseSet<FuncOp>& visited_functions,
    absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::Block& block = function.front();

  // Extract input & output names if set.
  llvm::SmallVector<llvm::StringRef, 2> input_names;
  llvm::SmallVector<llvm::StringRef, 2> output_names;
  llvm::SmallVector<llvm::StringRef, 2> unique_output_names;
  auto dict_attr =
      function->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
  if (dict_attr) {
    TF_RET_CHECK(dict_attr.get("inputs").isa<mlir::StringAttr>())
        << "inputs missing in entry function attribute";
    TF_RET_CHECK(dict_attr.get("outputs").isa<mlir::StringAttr>())
        << "outputs missing in entry function attribute";
    // The attribute stores names as comma-separated lists.
    dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
        input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
    dict_attr.get("outputs").cast<mlir::StringAttr>().getValue().split(
        output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
  }

  auto graph = std::make_unique<Graph>(OpRegistry::Global());

  // Extract version info.
  VersionDef versions;
  auto module = function->getParentOfType<mlir::ModuleOp>();
  if (mlir::succeeded(ExtractTfVersions(module, &versions))) {
    graph->set_versions(versions);
  }

  Exporter exporter(graph.get(), tf_dialect);

  auto graph_op = llvm::cast<mlir::tf_executor::GraphOp>(block.front());

  // Set input and output names and increment the use counter for them to help
  // generate unique names.
  if (!output_names.empty()) {
    const int num_data_results = graph_op.getNumResults();
    const int64_t output_names_size = output_names.size();
    TF_RET_CHECK(output_names_size == num_data_results)
        << "output names (" << output_names.size()
        << ") != terminator operands (" << num_data_results << ")";
    llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
    llvm::StringMap<Operation*> name_to_op;
    for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
      // Skip control rets.
      const int64_t index = it.index();
      if (index >= num_data_results) break;
      // TODO(jpienaar): If there is a result index specified, ensure only one
      // and that it matches the result index of the op.
      std::string name(output_names[index]);
      auto tensor_id = ParseTensorName(name);
      std::string tensor_id_node(tensor_id.node());
      assert(!tensor_id_node.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(tensor_id_node);

      // Ensure name does not get reused.
      unique_output_names.push_back(
          exporter.op_to_name_.GetUniqueName(tensor_id_node));
    }
  }

  if (!input_names.empty()) {
    TF_RET_CHECK(input_names.size() == block.getNumArguments());
    for (const auto& it : llvm::enumerate(function.getArguments())) {
      // TODO(lyandy): Update when changing feed/fetch import.
      std::string name(input_names[it.index()]);
      assert(!name.empty() && "expected non-empty name");
      mlir::LegalizeNodeName(name);
      auto tensor_id = ParseTensorName(name);
      TF_RET_CHECK(tensor_id.index() == 0)
          << "input port designation not supported";
      // Only assign user of argument the input name if the main graph did not
      // have its _Arg nodes lifted into the functions arguments.
      // Ensure name does not get reused.
      (void)exporter.op_to_name_.GetUniqueName(name);
    }
  }

  // Adds nodes for basic block (function) arguments.
  for (auto it : llvm::enumerate(block.getArguments())) {
    int index = it.index();
    auto arg = it.value();
    mlir::Type type = arg.getType();
    if (!type.isa<mlir::TensorType>()) {
      return errors::InvalidArgument(
          "FuncOps arguments must have tensor types. Found ",
          mlir::debugString(type), " in function ", function.getName().str());
    }

    TF_RETURN_IF_ERROR(exporter.AddArgumentNode(
        arg, index, !input_names.empty() ? input_names[index] : ""));
  }

  // Converts the function named `name` (when present in the symbol table)
  // into the library and merges the library into the graph.
  auto convert_called_function = [&](llvm::StringRef name) {
    auto func = symbol_table.lookup<FuncOp>(name);
    if (func != nullptr) {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                            func, flib, visited_functions));
      // TODO(prakalps): Optimize to only add the requested function to graph
      // library rather than the all the functions exported so far.
      TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
    }
    return OkStatus();
  };

  // Adds nodes for operations.
  for (Operation& inst : graph_op.GetBody()) {
    for (auto type : inst.getResultTypes())
      if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
                    mlir::tf_executor::TokenType>())
        return errors::InvalidArgument(
            "Values must be of tensor type, TensorFlow control type, or "
            "TensorFlow token type. Found ",
            mlir::debugString(type));

    if (llvm::isa<mlir::tf_executor::NextIterationSourceOp>(inst)) {
      // Skip tf_executor.NextIteration.Source as associated
      // tf_executor.NextIteration.Sink will be used instead.
      continue;
    } else if (auto fetch = llvm::dyn_cast<mlir::tf_executor::FetchOp>(inst)) {
      TF_RETURN_IF_ERROR(
          exporter.AddFetchNode(function, fetch, unique_output_names));
    } else if (auto island =
                   llvm::dyn_cast<mlir::tf_executor::IslandOp>(inst)) {
      Operation& inner_op = island.GetBody().front();
      auto op_name = GetTensorFlowOpName(inner_op.getName().getStringRef());
      if (op_name.ok()) {
        // If it is TF Control dialect specific op, look up custom operation
        // in the module and first convert that, then add it to function
        // definition library
        // TODO(prakalps): If two functions have cyclic dependence, this will
        // introduce an infinite loop.
        TF_RETURN_IF_ERROR(convert_called_function(op_name.ValueOrDie().str()));
      }

      if (IsLegacyCallInstruction(&inner_op)) {
        TF_RETURN_IF_ERROR(convert_called_function(
            inner_op.getAttrOfType<mlir::SymbolRefAttr>("f")
                .getLeafReference()
                .getValue()));
      }

      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inner_op));
    } else {
      TF_RETURN_IF_ERROR(exporter.AddInstructionNode(&inst));
    }
  }
  // Adds edges between the argument, operation and return nodes.
  for (Operation& inst : graph_op.GetBody()) {
    TF_RETURN_IF_ERROR(exporter.AddEdge(&inst));
  }
  // Fixes the edges between the inserted nodes and special "_SOURCE" and
  // "_SINK".
  FixupSourceAndSinkEdges(graph.get());

  TF_RETURN_IF_ERROR(
      exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes));

  return graph;
}
587
// Converts `function` into a FunctionDef (plus a GradientDef when a gradient
// attribute is present) and appends it to `flib`. `visited_functions` guards
// against converting the same function twice when functions reference each
// other.
Status Exporter::ConvertLibFunction(
    const GraphExportConfig& configs, const Dialect* tf_dialect,
    const SymbolTable& symbol_table, FuncOp function, FunctionDefLibrary* flib,
    llvm::SmallDenseSet<FuncOp>& visited_functions) {
  // Return early if the function has already been exported.
  bool is_new_function = visited_functions.insert(function).second;
  if (!is_new_function) return OkStatus();

  auto function_name = function.getName().str();

  // TODO(fengliuai): use a small flib_def to reduce overhead
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_ASSIGN_OR_RETURN(
      auto sub_graph,
      Exporter::Convert(configs, tf_dialect, symbol_table, function, flib,
                        visited_functions, &control_ret_nodes));
  // Only nodes collected as control rets are assigned a control return name
  // in the FunctionDef.
  const auto control_ret = [&](const Node* n) -> std::optional<string> {
    return control_ret_nodes.contains(n)
               ? absl::make_optional<string>(n->name())
               : std::nullopt;
  };
  FunctionDef func_def;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*sub_graph, function_name, control_ret, &func_def));

  // The node defs in FunctionDef might contain debug info which was added
  // by the GraphToFunctionDef method. We should remove it if we don't want
  // to export them to avoid failing the roundtrip test.
  if (!configs.export_debug_info) {
    for (auto& node_def : *func_def.mutable_node_def()) {
      node_def.clear_experimental_debug_info();
    }
  }

  // Checks for gradient attribute. If present converts the gradient function
  // and populates the GradientDef.
  auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
  if (auto attr =
          function->getAttrOfType<mlir::FlatSymbolRefAttr>(grad_string)) {
    auto grad_func = symbol_table.lookup<FuncOp>(attr.getValue());
    TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                          grad_func, flib, visited_functions));
    GradientDef grad;
    grad.set_function_name(function_name);
    grad.set_gradient_func(grad_func.getName().str());
    *flib->add_gradient() = grad;
  }

  auto stateful_string = mlir::TF::TensorFlowDialect::GetStatefulAttrName();
  if (auto attr = function->getAttrOfType<mlir::UnitAttr>(stateful_string)) {
    func_def.mutable_signature()->set_is_stateful(true);
  }

  // Ignore the gradient and is_stateful attribute on the function as they have
  // been handled above. Ignore the entry func attribute as it is an MLIR
  // metadata attribute and is not required in the function definition.
  absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
      grad_string.data(), stateful_string.data(), kEntryFuncAttr};
  llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
      function->getDialectAttrs());
  TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore,
                                       /*remove_ref_type=*/false,
                                       func_def.mutable_attr()));

  // Propagate per-argument resource unique IDs into the FunctionDef.
  for (int i = 0, e = function.getNumArguments(); i < e; ++i) {
    if (auto resource_arg_unique_id_attr =
            function.getArgAttrOfType<mlir::IntegerAttr>(
                i, kResourceArgUniqueIdAttr)) {
      (*func_def.mutable_resource_arg_unique_id())[i] =
          resource_arg_unique_id_attr.getInt();
    }
  }

  (*flib->add_function()) = std::move(func_def);
  return OkStatus();
}
664
// Converts `module` to a Graph. Functions other than the entry function
// "main" are exported as library functions; "main" becomes the graph body
// unless configs.export_entry_func_to_flib requests it be placed in the
// function library too. All converted functions/gradients are also mirrored
// into `flib_def`.
Status Exporter::Convert(mlir::ModuleOp module,
                         const GraphExportConfig& configs,
                         std::unique_ptr<Graph>* graph,
                         FunctionLibraryDefinition* flib_def,
                         absl::flat_hash_set<Node*>* control_ret_nodes) {
  mlir::StringAttr entry_func_id =
      mlir::StringAttr::get(module.getContext(), "main");
  std::optional<FuncOp> entry_func;
  FunctionDefLibrary flib;
  llvm::SmallDenseSet<FuncOp> visited_functions;
  auto tf_dialect = module.getContext()->getLoadedDialect("tf");
  // Construct SymbolTable to enable cheap function lookups. The cost
  // of constructing the table is offset by the number of queries.
  SymbolTable symbol_table(module);
  for (auto function : module.getOps<FuncOp>()) {
    if (function.isExternal())
      return errors::FailedPrecondition("External functions not supported");

    if (function.getName() == entry_func_id &&
        !configs.export_entry_func_to_flib) {
      entry_func.emplace(function);
    } else {
      TF_RETURN_IF_ERROR(ConvertLibFunction(configs, tf_dialect, symbol_table,
                                            function, &flib,
                                            visited_functions));
    }
  }

  if (!configs.export_entry_func_to_flib) {
    if (!entry_func.has_value())
      return errors::FailedPrecondition(
          "entry function `main` must be present");

    // Updates the graph and the function library definition.
    TF_ASSIGN_OR_RETURN(
        *graph,
        Exporter::Convert(configs, tf_dialect, symbol_table, entry_func.value(),
                          &flib, visited_functions, control_ret_nodes));
    // Add FunctionDefs and GradientDefs of MLIR functions to graph's function
    // library. If duplicate FunctionDefs already exist (can happen if exporter
    // had already added some FunctionDefs to the library to support legacy
    // calls), they are ignored.
    TF_RETURN_IF_ERROR(graph->get()->AddFunctionLibrary(flib));
  }

  // Mirror every converted function and gradient into the caller-provided
  // function library definition.
  for (auto& func_def : flib.function()) {
    TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
  }
  for (auto& grad_def : flib.gradient()) {
    TF_RETURN_IF_ERROR(flib_def->AddGradientDef(grad_def));
  }
  return OkStatus();
}
718 } // namespace
719
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,absl::flat_hash_set<Node * > * control_ret_nodes)720 Status ConvertMlirToGraph(mlir::ModuleOp module,
721 const GraphExportConfig& configs,
722 std::unique_ptr<Graph>* graph,
723 FunctionLibraryDefinition* flib_def,
724 absl::flat_hash_set<Node*>* control_ret_nodes) {
725 mlir::StatusScopedDiagnosticHandler sh(module.getContext());
726 if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus();
727 return sh.Combine(
728 Exporter::Convert(module, configs, graph, flib_def, control_ret_nodes));
729 }
730
ConvertMlirToGraph(mlir::ModuleOp module,const GraphExportConfig & configs,std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def)731 Status ConvertMlirToGraph(mlir::ModuleOp module,
732 const GraphExportConfig& configs,
733 std::unique_ptr<Graph>* graph,
734 FunctionLibraryDefinition* flib_def) {
735 absl::flat_hash_set<Node*> control_ret_nodes;
736 return ConvertMlirToGraph(module, configs, graph, flib_def,
737 &control_ret_nodes);
738 }
739
ConvertMlirToGraphdef(mlir::ModuleOp module,const GraphExportConfig & configs)740 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
741 mlir::ModuleOp module, const GraphExportConfig& configs) {
742 FunctionLibraryDefinition flib_def(OpRegistry::Global(),
743 FunctionDefLibrary());
744 std::unique_ptr<Graph> graph;
745 TF_RETURN_IF_ERROR(ConvertMlirToGraph(module, configs, &graph, &flib_def));
746
747 // If the entry function is exported to flib, then no graph is constructed.
748 // Construct one in that case.
749 if (configs.export_entry_func_to_flib) {
750 graph = std::make_unique<Graph>(OpRegistry::Global());
751 // TODO(hinsu): Avoid Proto -> Memory -> Proto conversion here.
752 FunctionDefLibrary flib = flib_def.ToProto();
753 TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(flib));
754 }
755
756 auto graphdef = std::make_unique<GraphDef>();
757 graph->ToGraphDef(graphdef.get());
758 if (!configs.export_library) graphdef->clear_library();
759 if (!configs.export_shapes) {
760 for (auto& node_def : *graphdef->mutable_node()) {
761 node_def.mutable_attr()->erase("shape");
762 }
763 }
764 if (!configs.export_debug_info) {
765 for (auto& node_def : *graphdef->mutable_node()) {
766 node_def.clear_experimental_debug_info();
767 }
768 }
769 return graphdef;
770 }
771
ConvertMlirFunctionToFunctionLibraryDef(FuncOp func,const GraphExportConfig & configs,FunctionDef * function_def)772 stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
773 FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) {
774 Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
775 FunctionDefLibrary flib;
776 llvm::SmallDenseSet<FuncOp> visited_functions;
777 // Construct SymbolTable to enable cheap function lookups. The cost
778 // of constructing the table is offset by the number of queries. Even
779 // though this only converts one function in theory, this function
780 // may have gradient associated which would result in a lookup. This
781 // could be made lazy if we find this to be broad.
782 SymbolTable symbol_table(func->getParentOfType<mlir::ModuleOp>());
783 TF_RETURN_IF_ERROR(Exporter::ConvertLibFunction(
784 configs, tf_dialect, symbol_table, func, &flib, visited_functions));
785 for (auto& func_def : flib.function()) {
786 if (func_def.signature().name() == func.getName()) {
787 *function_def = func_def;
788 return OkStatus();
789 }
790 }
791 return errors::InvalidArgument(
792 "Function couldn't be found in the FunctionDefLibrary after converting "
793 "from MLIR");
794 }
795
796 } // namespace tensorflow
797